diff --git a/README.md b/README.md
index 0cad06b5de366188f3e251ef2ff54dc7e8bb2f87..219f3b3c19af202414aa8dbc0b6a885a05ffa7c7 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,37 @@ Activation is a python package that contains custom CUDA-based activation kernel
- Currently implemented
- [PolyNorm](https://arxiv.org/html/2411.03884v1)
- [RMSNorm](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html)
+ - **FusedAddRMSNorm**
+
+ A fused operator that combines **residual addition** (`x + residual`) with **RMSNorm** in a single kernel.
+ - Instead of:
+
+ ```python
+ y = x + residual
+ out = rms_norm(y, weight, eps)
+ ```
+
+ - Fused as:
+
+ ```python
+ out = fused_add_rms_norm(x, residual, weight, eps)
+ ```
+
+ - **FusedMulPolyNorm**
+
+ A fused operator that combines **PolyNorm** with an **element-wise multiplication** by a Tensor.
+ - Instead of:
+
+ ```python
+ y = poly_norm(x, weight, bias, eps)
+ out = y * a
+ ```
+
+ - Fused as:
+
+ ```python
+ out = fused_mul_poly_norm(x, a, weight, bias, eps)
+ ```
## Usage
@@ -28,18 +59,158 @@ print(poly_norm(x))
```
## Performance
+- Test cases are from the Motif LLM
+- The results can be reproduced using the provided benchmarking tools.
+- For details on how to use the benchmarking tools, please refer to the [benchmarks README](./benchmarks/README.md).
+- The benchmark results may show fluctuations, especially in the backward pass and when the dimension size is small.
+
+### RMSNorm
+
+#### H100 Results
+
+
+Forward Performance
+
+
+
+
+
+
+Backward Performance
+
+
+
+
+
+#### MI250 Results
+
+
+Forward Performance
+
+
+
+
+
+
+Backward Performance
+
+
+
+
+
+---
+
+### FusedAddRMSNorm
+
+> [!NOTE]
+> For the fused-kernel benchmarks, the **non-fused baseline** was implemented with our **custom kernels**.
+
+#### H100 Results
+
+
+Forward Performance
+
+
+
+
+
+
+Backward Performance
+
+
+
+
+
+#### MI250 Results
+
+
+Forward Performance
+
+
+
+
+
+
+Backward Performance
+
+
+
+
+
+---
### PolyNorm
-- Test cases are from the Motif LLM
-- You can reproduce the results with:
+#### H100 Results
-```bash
-cd tests
-pytest --run-perf --do-plot
-```
+
+Forward Performance
+
+
+
+
+
+
+Backward Performance
+
+
+
+
+
+#### MI250 Results
+
+
+Forward Performance
+
+
+
+
+
+
+Backward Performance
+
+
+
+
+
+---
+
+### FusedMulPolyNorm
+
+> [!NOTE]
+> For the fused-kernel benchmarks, the **non-fused baseline** was implemented with our **custom kernels**.
+
+#### H100 Results
+
+
+Forward Performance
+
+
+
+
+
+
+Backward Performance
+
+
+
+
+
+#### MI250 Results
+
+
+Forward Performance
+
+
+
+
+
+
+Backward Performance
+
+
-
+
## Pre-commit Hooks
diff --git a/activation/block_reduce.h b/activation/block_reduce.h
deleted file mode 100644
index 61c56e3b4f71646ddfeb5f879bc8169273c285f8..0000000000000000000000000000000000000000
--- a/activation/block_reduce.h
+++ /dev/null
@@ -1,21 +0,0 @@
-namespace motif {
-
-template
-__device__ acc_t _block_reduce_sum(acc_t *shared, const float val,
- const int d) {
- // TODO: Optimize with warp-level primitives
- __syncthreads();
-
- shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f;
- __syncthreads();
- for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
- if (threadIdx.x < stride) {
- shared[threadIdx.x] += shared[threadIdx.x + stride];
- }
- __syncthreads();
- }
-
- return shared[0];
-}
-
-} // namespace motif
diff --git a/activation/fused_add_rms_norm.cu b/activation/fused_add_rms_norm.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9e73bd8f61fbd1ad59824398cf6b43e121071ea8
--- /dev/null
+++ b/activation/fused_add_rms_norm.cu
@@ -0,0 +1,157 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <cmath>

#include <cub/cub.cuh>

#include "assert_utils.h"
#include "atomic_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
+
+namespace motif {
+
// Fixed-width vector of N elements of `type`, aligned to its full size so a
// single vectorized global-memory transaction can load/store the whole struct.
template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
  type data[N];
};
+
// Vectorized (width > 0) fused residual-add + RMSNorm.
// Grid: one block per token; threads stride over the d / width vector chunks.
// Preconditions: d % width == 0; pointers alignable to sizeof(scalar_t)*width.
// Writes x + residual to add_out and the normalized result to out.
template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width > 0)>
fused_add_rms_norm_kernel(scalar_t *__restrict__ out,     // [..., d]
                          scalar_t *__restrict__ add_out, // [..., d]
                          const scalar_t *__restrict__ input,    // [..., d]
                          const scalar_t *__restrict__ residual, // [..., d]
                          const scalar_t *__restrict__ weight,   // [d]
                          const float eps, const int d) {
  using vec_t = type_vec_t<scalar_t, width>;

  const int vec_d = d / width;
  // Widen before multiplying: blockIdx.x * vec_d evaluated in 32 bits can
  // overflow for large token counts.
  const int64_t vec_offset = static_cast<int64_t>(blockIdx.x) * vec_d;
  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
  const vec_t *__restrict__ residual_vec =
      reinterpret_cast<const vec_t *>(residual);
  vec_t *__restrict__ add_out_vec = reinterpret_cast<vec_t *>(add_out);
  acc_t sum_square = 0.0f;

  // Pass 1: compute x + residual, persist it to add_out, and accumulate the
  // per-thread sum of squares for the RMS statistic.
  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];
    vec_t res_vec = residual_vec[vec_offset + idx];
    vec_t add_vec;

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x = x_vec.data[i] + res_vec.data[i];
      sum_square += x * x;
      add_vec.data[i] = x;
    }
    add_out_vec[vec_offset + idx] = add_vec;
  }

  // Block-wide reduction of the partial sums. 1024 matches the maximum block
  // size chosen by the host launcher.
  using BlockReduce = cub::BlockReduce<acc_t, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);

  __shared__ acc_t s_scale;

  if (threadIdx.x == 0) {
    s_scale = rsqrtf(sum_square / d + eps);
  }
  __syncthreads();

  const vec_t *__restrict__ weight_vec =
      reinterpret_cast<const vec_t *>(weight);
  vec_t *__restrict__ output_vec = reinterpret_cast<vec_t *>(out);

  // Pass 2: each thread re-reads exactly the add_out elements it wrote in
  // pass 1 (same idx pattern), so no extra barrier is needed beyond the one
  // protecting s_scale.
  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = add_out_vec[vec_offset + idx];
    vec_t w_vec = weight_vec[idx];
    vec_t y_vec;

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x = x_vec.data[i];
      acc_t w = w_vec.data[i];

      y_vec.data[i] = w * x * s_scale;
    }
    output_vec[vec_offset + idx] = y_vec;
  }
}
+
// Scalar (width == 0) fallback of the fused residual-add + RMSNorm kernel,
// used when d is not divisible by the vector width. Same layout contract as
// the vectorized version: one block per token, threads stride over d.
template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width == 0)>
fused_add_rms_norm_kernel(scalar_t *__restrict__ out,     // [..., d]
                          scalar_t *__restrict__ add_out, // [..., d]
                          const scalar_t *__restrict__ input,    // [..., d]
                          const scalar_t *__restrict__ residual, // [..., d]
                          const scalar_t *__restrict__ weight,   // [d]
                          const float eps, const int d) {
  const int64_t token_idx = blockIdx.x;
  const int64_t vec_idx = threadIdx.x;
  acc_t sum_square = 0.0f;

  // Pass 1: residual add + accumulate sum of squares; stash x + residual in
  // add_out so pass 2 can reuse it.
  for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
    acc_t x = input[token_idx * d + idx] + residual[token_idx * d + idx];
    sum_square += x * x;
    add_out[token_idx * d + idx] = x;
  }

  using BlockReduce = cub::BlockReduce<acc_t, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);

  __shared__ acc_t s_scale;

  if (vec_idx == 0) {
    s_scale = rsqrtf(sum_square / d + eps);
  }
  __syncthreads();

  // Pass 2: apply the weighted RMS scale. Each thread re-reads the elements
  // it wrote itself, so the s_scale barrier above is sufficient.
  for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
    acc_t x = add_out[token_idx * d + idx];
    acc_t w = weight[idx];
    out[token_idx * d + idx] = w * x * s_scale;
  }
}
+
+} // namespace motif
+
// Dispatches on the tensor dtype and launches the kernel specialization.
// `width` selects vectorized (8) vs scalar (0) code paths; accumulation is
// always performed in float.
#define LAUNCH_RMS_NORM(width)                                                 \
  MOTIF_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        motif::fused_add_rms_norm_kernel<scalar_t, float, width>               \
            <<<grid, block, 0, stream>>>(                                      \
                out.data_ptr<scalar_t>(), add_out.data_ptr<scalar_t>(),        \
                input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),     \
                weight.data_ptr<scalar_t>(), eps, d);                          \
      });
+
// Host launcher for the fused residual-add + RMSNorm forward pass.
// out     <- rms_norm(input + residual, weight, eps)
// add_out <- input + residual (saved for the backward pass / later reuse)
void fused_add_rms_norm(torch::Tensor &out,            // [..., d]
                        torch::Tensor &add_out,        // [..., d]
                        const torch::Tensor &input,    // [..., d]
                        const torch::Tensor &residual, // [..., d]
                        const torch::Tensor &weight,   // [d]
                        double eps) {
  AssertTensorShapeEqual(input, residual, "input", "residual");
  AssertTensorShapeEqual(input, out, "input", "out");
  AssertTensorShapeEqual(input, add_out, "input", "add_out");
  AssertTensorNotNull(weight, "weight");
  // TODO shape check

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  // Few tokens -> use wide blocks to keep the GPU busy; many tokens -> smaller
  // blocks so more blocks can be resident per SM.
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(d, max_block_size));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Vectorize by 8 whenever the row length allows it.
  if (d % 8 == 0) {
    LAUNCH_RMS_NORM(8);
  } else {
    LAUNCH_RMS_NORM(0);
  }
}
diff --git a/activation/fused_mul_poly_norm.cu b/activation/fused_mul_poly_norm.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ce35c1dd67be0fd4c41658b55fc433bd87d1038a
--- /dev/null
+++ b/activation/fused_mul_poly_norm.cu
@@ -0,0 +1,642 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <cmath>

#include <cub/cub.cuh>

#include "assert_utils.h"
#include "atomic_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
+
+namespace motif {
+
// Fixed-width vector of N elements of `type`, aligned to its full size so a
// single vectorized global-memory transaction can load/store the whole struct.
template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
  type data[N];
};

// Component-wise float3 addition: lets one cub::BlockReduce pass reduce three
// partial sums (sum2 / sum4 / sum6) at once.
struct SumOp {
  __device__ float3 operator()(const float3 &a, const float3 &b) const {
    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
  }
};

// Component-wise float4 addition for the four backward-pass partial sums
// (sum_dy / sum_dw0 / sum_dw1 / sum_dw2).
struct SumOp4 {
  __device__ float4 operator()(const float4 &a, const float4 &b) const {
    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
  }
};
+
// Vectorized (width > 0) fused PolyNorm + element-wise multiply:
//   out = poly_norm(input, weight, bias, eps) * mul
// Grid: one block per token; threads stride over the d / width vector chunks.
// Preconditions: d % width == 0; pointers alignable to sizeof(scalar_t)*width.
template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width > 0)>
fused_mul_poly_norm_kernel(scalar_t *__restrict__ out,   // [..., d]
                           const scalar_t *__restrict__ input,  // [..., d]
                           const scalar_t *__restrict__ mul,    // [..., d]
                           const scalar_t *__restrict__ weight, // [3]
                           const scalar_t *__restrict__ bias,   // [1]
                           const float eps, const int d) {
  using vec_t = type_vec_t<scalar_t, width>;

  const int vec_d = d / width;
  // Widen before multiplying to avoid 32-bit overflow for large token counts.
  const int64_t vec_offset = static_cast<int64_t>(blockIdx.x) * vec_d;
  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);

  // Partial sums of x^2, x^4, x^6 for the three RMS statistics.
  acc_t sum2 = 0.0f;
  acc_t sum4 = 0.0f;
  acc_t sum6 = 0.0f;

  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x4 = x2 * x2;
      acc_t x6 = x4 * x2;

      sum2 += x2;
      sum4 += x4;
      sum6 += x6;
    }
  }

  // Reduce all three partials in a single BlockReduce pass via float3.
  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float3 thread_sums = make_float3(sum2, sum4, sum6);
  float3 block_sums =
      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);

  sum2 = block_sums.x;
  sum4 = block_sums.y;
  sum6 = block_sums.z;

  __shared__ acc_t s_bias;

  // Pre-fold each weight into its reciprocal std so the per-element loop is a
  // fused multiply-add chain.
  __shared__ acc_t s_w2_inv_std1;
  __shared__ acc_t s_w1_inv_std2;
  __shared__ acc_t s_w0_inv_std3;

  if (threadIdx.x == 0) {
    acc_t w0 = weight[0];
    acc_t w1 = weight[1];
    acc_t w2 = weight[2];
    s_bias = bias[0];

    s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2;
    s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1;
    s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0;
  }
  __syncthreads();

  acc_t w2_inv_std1 = s_w2_inv_std1;
  acc_t w1_inv_std2 = s_w1_inv_std2;
  acc_t w0_inv_std3 = s_w0_inv_std3;
  acc_t bias_reg = s_bias;

  vec_t *__restrict__ output_vec = reinterpret_cast<vec_t *>(out);
  const vec_t *__restrict__ mul_vec = reinterpret_cast<const vec_t *>(mul);

  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];
    vec_t m_vec = mul_vec[vec_offset + idx];
    vec_t y_vec;

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x1 = x_vec.data[i];
      scalar_t m = m_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;
      // Rounded to scalar_t before the multiply so the result matches the
      // non-fused poly_norm-then-mul reference bit-for-bit.
      scalar_t poly_norm_result =
          x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
      y_vec.data[i] = poly_norm_result * m;
    }
    output_vec[vec_offset + idx] = y_vec;
  }
}
+
// Scalar (width == 0) fallback of the fused PolyNorm + multiply kernel, used
// when d is not divisible by the vector width. Same math as the vectorized
// version, element-at-a-time.
template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width == 0)>
fused_mul_poly_norm_kernel(scalar_t *__restrict__ out,   // [..., d]
                           const scalar_t *__restrict__ input,  // [..., d]
                           const scalar_t *__restrict__ mul,    // [..., d]
                           const scalar_t *__restrict__ weight, // [3]
                           const scalar_t *__restrict__ bias,   // [1]
                           const float eps, const int d) {
  const int64_t token_idx = blockIdx.x;

  // Partial sums of x^2, x^4, x^6 for the three RMS statistics.
  acc_t sum2 = 0.0f;
  acc_t sum4 = 0.0f;
  acc_t sum6 = 0.0f;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t x1 = input[token_idx * d + idx];
    acc_t x2 = x1 * x1;
    acc_t x4 = x2 * x2;
    acc_t x6 = x4 * x2;

    sum2 += x2;
    sum4 += x4;
    sum6 += x6;
  }

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float3 thread_sums = make_float3(sum2, sum4, sum6);
  float3 block_sums =
      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);

  sum2 = block_sums.x;
  sum4 = block_sums.y;
  sum6 = block_sums.z;

  __shared__ acc_t s_bias;

  __shared__ acc_t s_w2_inv_std1;
  __shared__ acc_t s_w1_inv_std2;
  __shared__ acc_t s_w0_inv_std3;

  if (threadIdx.x == 0) {
    acc_t w0 = weight[0];
    acc_t w1 = weight[1];
    acc_t w2 = weight[2];
    s_bias = bias[0];

    s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2;
    s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1;
    s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0;
  }
  __syncthreads();

  acc_t w2_inv_std1 = s_w2_inv_std1;
  acc_t w1_inv_std2 = s_w1_inv_std2;
  acc_t w0_inv_std3 = s_w0_inv_std3;
  acc_t bias_reg = s_bias;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t x1 = input[token_idx * d + idx];
    scalar_t m = mul[token_idx * d + idx];
    acc_t x2 = x1 * x1;
    acc_t x3 = x2 * x1;
    // Rounded to scalar_t before the multiply to match the non-fused path.
    scalar_t poly_norm_result =
        x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
    out[token_idx * d + idx] = poly_norm_result * m;
  }
}
+
// Vectorized (width > 0) backward pass of fused mul + PolyNorm.
// One block per token. Produces:
//   input_grad / mul_grad  - per-element gradients (either may be null, in
//                            which case that output is skipped)
//   temp_weight_grad[b, 3] - per-token partial weight gradients
//   temp_bias_grad[b]      - per-token partial bias gradient
// The host reduces the temp buffers across tokens.
template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width > 0)> fused_mul_poly_norm_backward_kernel(
    scalar_t *__restrict__ input_grad,        // [..., d]
    scalar_t *__restrict__ mul_grad,          // [..., d]
    acc_t *__restrict__ temp_weight_grad,     // [..., 3]
    acc_t *__restrict__ temp_bias_grad,       // [..., 1]
    const scalar_t *__restrict__ output_grad, // [..., d]
    const scalar_t *__restrict__ input,       // [..., d]
    const scalar_t *__restrict__ mul,         // [..., d]
    const scalar_t *__restrict__ weight,      // [3]
    const scalar_t *__restrict__ bias,        // [1]
    const float eps, const int d) {
  using vec_t = type_vec_t<scalar_t, width>;

  const int vec_d = d / width;
  // Widen before multiplying to avoid 32-bit overflow for large token counts.
  const int64_t vec_offset = static_cast<int64_t>(blockIdx.x) * vec_d;
  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
  const vec_t *__restrict__ mul_vec = reinterpret_cast<const vec_t *>(mul);
  const vec_t *__restrict__ output_grad_vec =
      reinterpret_cast<const vec_t *>(output_grad);

  acc_t sum2 = 0.0f;
  acc_t sum4 = 0.0f;
  acc_t sum6 = 0.0f;

  acc_t sum_dx1 = 0.0f;
  acc_t sum_dx2 = 0.0f;
  acc_t sum_dx3 = 0.0f;

  // Pass 1: accumulate the power sums (for the RMS statistics) and the
  // dy-weighted power sums (for the normalization gradient terms). dy is the
  // upstream gradient folded with the multiply: d(out)/d(poly) = mul.
  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];
    vec_t dy_fused_vec = output_grad_vec[vec_offset + idx];
    vec_t m_vec = mul_vec[vec_offset + idx];

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;
      acc_t x4 = x2 * x2;
      acc_t x6 = x3 * x3;

      sum2 += x2;
      sum4 += x4;
      sum6 += x6;

      acc_t dy = dy_fused_vec.data[i] * m_vec.data[i];

      sum_dx1 += dy * x1;
      sum_dx2 += dy * x2;
      sum_dx3 += dy * x3;
    }
  }

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float3 thread_sums = make_float3(sum2, sum4, sum6);
  float3 block_sums =
      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);

  sum2 = block_sums.x;
  sum4 = block_sums.y;
  sum6 = block_sums.z;

  float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3);
  // TempStorage is reused for the second reduction; barrier first.
  __syncthreads();
  float3 block_sum_dxs =
      BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x);

  sum_dx1 = block_sum_dxs.x;
  sum_dx2 = block_sum_dxs.y;
  sum_dx3 = block_sum_dxs.z;

  // Broadcast the block-level statistics (only thread 0 holds valid reduction
  // results) to all threads via shared memory.
  __shared__ acc_t s_mean2;
  __shared__ acc_t s_mean4;
  __shared__ acc_t s_mean6;
  __shared__ acc_t s_sdx1;
  __shared__ acc_t s_sdx2;
  __shared__ acc_t s_sdx3;

  const acc_t inv_d = acc_t(1) / d;

  if (threadIdx.x == 0) {
    s_mean2 = sum2 * inv_d + eps;
    s_mean4 = sum4 * inv_d + eps;
    s_mean6 = sum6 * inv_d + eps;

    s_sdx1 = sum_dx1 * inv_d;
    s_sdx2 = sum_dx2 * inv_d;
    s_sdx3 = sum_dx3 * inv_d;
  }
  __syncthreads();

  acc_t w0 = weight[0];
  acc_t w1 = weight[1];
  acc_t w2 = weight[2];
  acc_t bias_reg = bias[0];

  acc_t mean2 = s_mean2;
  acc_t mean4 = s_mean4;
  acc_t mean6 = s_mean6;
  acc_t sdx1 = s_sdx1;
  acc_t sdx2 = s_sdx2;
  acc_t sdx3 = s_sdx3;

  acc_t inv_std1 = rsqrtf(mean2);
  acc_t inv_std2 = rsqrtf(mean4);
  acc_t inv_std3 = rsqrtf(mean6);

  acc_t w2_inv_std1 = inv_std1 * w2;
  acc_t w1_inv_std2 = inv_std2 * w1;
  acc_t w0_inv_std3 = inv_std3 * w0;

  // inv_std / mean == powf(mean, -1.5)
  acc_t c1 = w2_inv_std1 / mean2;
  acc_t c2 = acc_t(2) * w1_inv_std2 / mean4;
  acc_t c3 = acc_t(3) * w0_inv_std3 / mean6;

  acc_t sum_dy = 0;
  acc_t sum_dw0 = 0;
  acc_t sum_dw1 = 0;
  acc_t sum_dw2 = 0;

  vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
  vec_t *__restrict__ mul_grad_vec = reinterpret_cast<vec_t *>(mul_grad);

  // Pass 2: emit per-element gradients and accumulate weight/bias partials.
  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
    vec_t x_vec = input_vec[vec_offset + idx];
    vec_t dy_fused_vec = output_grad_vec[vec_offset + idx];
    vec_t m_vec = mul_vec[vec_offset + idx];
    vec_t dx_vec;
    vec_t dm_vec;

#pragma unroll
    for (int i = 0; i < width; ++i) {
      acc_t x1 = x_vec.data[i];
      acc_t x2 = x1 * x1;
      acc_t x3 = x2 * x1;
      acc_t dy = dy_fused_vec.data[i] * m_vec.data[i];

      // For register optimization, the order of the following logic matters.
      // The input_grad related logic must be placed at the very end.
      sum_dy += dy;
      sum_dw0 += dy * (x3 * inv_std3);
      sum_dw1 += dy * (x2 * inv_std2);
      sum_dw2 += dy * (x1 * inv_std1);

      if (mul_grad) {
        // d(out)/d(mul) = poly_norm(x); recompute it from the statistics.
        scalar_t poly_norm_result =
            x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
        dm_vec.data[i] = poly_norm_result * dy_fused_vec.data[i];
      }

      if (input_grad) {
        acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
        acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
        acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
        dx_vec.data[i] = dx1 + dx2 + dx3;
      }
    }

    if (input_grad) {
      input_grad_vec[vec_offset + idx] = dx_vec;
    }
    if (mul_grad) {
      mul_grad_vec[vec_offset + idx] = dm_vec;
    }
  }

  using BlockReduce4 = cub::BlockReduce<float4, 1024>;
  __shared__ typename BlockReduce4::TempStorage reduceStore4;

  float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2);
  float4 block_sum_ds =
      BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x);

  sum_dy = block_sum_ds.x;
  sum_dw0 = block_sum_ds.y;
  sum_dw1 = block_sum_ds.z;
  sum_dw2 = block_sum_ds.w;

  if (threadIdx.x == 0) {
    temp_bias_grad[blockIdx.x] = sum_dy;
    temp_weight_grad[blockIdx.x * 3 + 0] = sum_dw0;
    temp_weight_grad[blockIdx.x * 3 + 1] = sum_dw1;
    temp_weight_grad[blockIdx.x * 3 + 2] = sum_dw2;
  }
}
+
// Scalar (width == 0) fallback of the fused mul + PolyNorm backward kernel.
// Same contract as the vectorized version: one block per token; input_grad and
// mul_grad may be null; per-token weight/bias partials go to the temp buffers.
template <typename scalar_t, typename acc_t, int width>
__global__ std::enable_if_t<(width == 0)> fused_mul_poly_norm_backward_kernel(
    scalar_t *__restrict__ input_grad,        // [..., d]
    scalar_t *__restrict__ mul_grad,          // [..., d]
    acc_t *__restrict__ temp_weight_grad,     // [..., 3]
    acc_t *__restrict__ temp_bias_grad,       // [..., 1]
    const scalar_t *__restrict__ output_grad, // [..., d]
    const scalar_t *__restrict__ input,       // [..., d]
    const scalar_t *__restrict__ mul,         // [..., d]
    const scalar_t *__restrict__ weight,      // [3]
    const scalar_t *__restrict__ bias,        // [1]
    const float eps, const int d) {
  const int64_t token_idx = blockIdx.x;

  acc_t sum2 = 0.0f;
  acc_t sum4 = 0.0f;
  acc_t sum6 = 0.0f;

  acc_t sum_dx1 = 0.0f;
  acc_t sum_dx2 = 0.0f;
  acc_t sum_dx3 = 0.0f;

  // Pass 1: power sums and dy-weighted power sums. dy folds the upstream
  // gradient with the multiply operand (d(out)/d(poly) = mul).
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t dy = output_grad[token_idx * d + idx] * mul[token_idx * d + idx];

    acc_t x1 = input[token_idx * d + idx];
    acc_t x2 = x1 * x1;
    acc_t x3 = x2 * x1;
    acc_t x4 = x2 * x2;
    acc_t x6 = x3 * x3;

    sum2 += x2;
    sum4 += x4;
    sum6 += x6;

    sum_dx1 += dy * x1;
    sum_dx2 += dy * x2;
    sum_dx3 += dy * x3;
  }

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  float3 thread_sums = make_float3(sum2, sum4, sum6);
  float3 block_sums =
      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);

  sum2 = block_sums.x;
  sum4 = block_sums.y;
  sum6 = block_sums.z;

  float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3);
  // TempStorage is reused for the second reduction; barrier first.
  __syncthreads();
  float3 block_sum_dxs =
      BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x);

  sum_dx1 = block_sum_dxs.x;
  sum_dx2 = block_sum_dxs.y;
  sum_dx3 = block_sum_dxs.z;

  // Broadcast the reduction results (valid only on thread 0) to the block.
  __shared__ acc_t s_mean2;
  __shared__ acc_t s_mean4;
  __shared__ acc_t s_mean6;
  __shared__ acc_t s_sdx1;
  __shared__ acc_t s_sdx2;
  __shared__ acc_t s_sdx3;

  const acc_t inv_d = acc_t(1) / d;

  if (threadIdx.x == 0) {
    s_mean2 = sum2 * inv_d + eps;
    s_mean4 = sum4 * inv_d + eps;
    s_mean6 = sum6 * inv_d + eps;

    s_sdx1 = sum_dx1 * inv_d;
    s_sdx2 = sum_dx2 * inv_d;
    s_sdx3 = sum_dx3 * inv_d;
  }
  __syncthreads();

  acc_t w0 = weight[0];
  acc_t w1 = weight[1];
  acc_t w2 = weight[2];
  acc_t bias_reg = bias[0];

  acc_t mean2 = s_mean2;
  acc_t mean4 = s_mean4;
  acc_t mean6 = s_mean6;
  acc_t sdx1 = s_sdx1;
  acc_t sdx2 = s_sdx2;
  acc_t sdx3 = s_sdx3;

  acc_t inv_std1 = rsqrtf(mean2);
  acc_t inv_std2 = rsqrtf(mean4);
  acc_t inv_std3 = rsqrtf(mean6);

  acc_t w2_inv_std1 = inv_std1 * w2;
  acc_t w1_inv_std2 = inv_std2 * w1;
  acc_t w0_inv_std3 = inv_std3 * w0;

  // inv_std / mean == powf(mean, -1.5)
  acc_t c1 = w2_inv_std1 / mean2;
  acc_t c2 = acc_t(2) * w1_inv_std2 / mean4;
  acc_t c3 = acc_t(3) * w0_inv_std3 / mean6;

  acc_t sum_dy = 0;
  acc_t sum_dw0 = 0;
  acc_t sum_dw1 = 0;
  acc_t sum_dw2 = 0;

  // Pass 2: per-element gradients plus weight/bias partial accumulation.
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    scalar_t dy_fused = output_grad[token_idx * d + idx];
    acc_t dy = dy_fused * mul[token_idx * d + idx];
    acc_t x1 = input[token_idx * d + idx];
    acc_t x2 = x1 * x1;
    acc_t x3 = x2 * x1;

    if (input_grad) {
      acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
      acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
      acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
      input_grad[token_idx * d + idx] = dx1 + dx2 + dx3;
    }

    if (mul_grad) {
      // d(out)/d(mul) = poly_norm(x), recomputed from the statistics.
      scalar_t poly_norm_result =
          x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
      mul_grad[token_idx * d + idx] = poly_norm_result * dy_fused;
    }

    sum_dy += dy;
    sum_dw0 += dy * (x3 * inv_std3);
    sum_dw1 += dy * (x2 * inv_std2);
    sum_dw2 += dy * (x1 * inv_std1);
  }

  using BlockReduce4 = cub::BlockReduce<float4, 1024>;
  __shared__ typename BlockReduce4::TempStorage reduceStore4;

  float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2);
  float4 block_sum_ds =
      BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x);

  sum_dy = block_sum_ds.x;
  sum_dw0 = block_sum_ds.y;
  sum_dw1 = block_sum_ds.z;
  sum_dw2 = block_sum_ds.w;

  if (threadIdx.x == 0) {
    temp_bias_grad[token_idx] = sum_dy;
    temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
    temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
    temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
  }
}
+
+} // namespace motif
+
// Dispatches on dtype and launches the forward kernel specialization.
// `width` selects vectorized (8) vs scalar (0); accumulation in float.
#define LAUNCH_FUSED_MUL_POLY_NORM(width)                                      \
  MOTIF_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "fused_mul_poly_norm_kernel", [&] {                 \
        motif::fused_mul_poly_norm_kernel<scalar_t, float, width>              \
            <<<grid, block, 0, stream>>>(                                      \
                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),          \
                mul.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),         \
                bias.data_ptr<scalar_t>(), eps, d);                            \
      });
+
// Host launcher for the fused PolyNorm + element-wise multiply forward pass:
//   out = poly_norm(input, weight, bias, eps) * mul
void fused_mul_poly_norm(torch::Tensor &out,          // [..., d]
                         const torch::Tensor &input,  // [..., d]
                         const torch::Tensor &mul,    // [..., d]
                         const torch::Tensor &weight, // [3]
                         const torch::Tensor &bias,   // [1]
                         double eps) {
  AssertTensorShapeEqual(input, out, "input", "out");
  AssertTensorShapeEqual(input, mul, "input", "mul");
  AssertTensorNotNull(weight, "weight");
  AssertTensorNotNull(bias, "bias");
  // TODO shape check

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  // Few tokens -> wide blocks for parallelism within a token; many tokens ->
  // smaller blocks so more blocks fit per SM.
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(d, max_block_size));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Vectorize by 8 whenever the row length allows it.
  if (d % 8 == 0) {
    LAUNCH_FUSED_MUL_POLY_NORM(8);
  } else {
    LAUNCH_FUSED_MUL_POLY_NORM(0);
  }
}
+
// Dispatches on dtype and launches the backward kernel specialization.
// input_grad / mul_grad are documented as nullable and the kernels guard with
// `if (input_grad)` / `if (mul_grad)`, so pass nullptr for undefined tensors
// instead of calling data_ptr() on them (which would throw).
#define LAUNCH_POLY_NORM_BACKWARD(width)                                       \
  MOTIF_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "fused_mul_poly_norm_backward_kernel", [&] {        \
        motif::fused_mul_poly_norm_backward_kernel<scalar_t, float, width>     \
            <<<grid, block, 0, stream>>>(                                      \
                input_grad.defined() ? input_grad.data_ptr<scalar_t>()         \
                                     : nullptr,                                \
                mul_grad.defined() ? mul_grad.data_ptr<scalar_t>() : nullptr,  \
                temp_weight_grad.data_ptr<float>(),                            \
                temp_bias_grad.data_ptr<float>(),                              \
                output_grad.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),  \
                mul.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),         \
                bias.data_ptr<scalar_t>(), eps, d);                            \
      });
+
// Host launcher for the fused mul + PolyNorm backward pass.
// The kernel writes per-token weight/bias partials into float temp buffers;
// they are then summed across tokens here and copied into weight_grad /
// bias_grad (converting back to the parameter dtype via copy_).
void fused_mul_poly_norm_backward(torch::Tensor &input_grad,       // [..., d]
                                  torch::Tensor &mul_grad,         // [..., d]
                                  torch::Tensor &weight_grad,      // [3]
                                  torch::Tensor &bias_grad,        // [1]
                                  const torch::Tensor &output_grad,// [..., d]
                                  const torch::Tensor &input,      // [..., d]
                                  const torch::Tensor &mul,        // [..., d]
                                  const torch::Tensor &weight,     // [3]
                                  const torch::Tensor &bias,       // [1]
                                  double eps) {
  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
  AssertTensorShapeEqual(input, mul_grad, "input", "mul_grad");
  AssertTensorShapeEqual(input, mul, "input", "mul");
  AssertTensorNotNull(weight, "weight");
  // TODO shape check
  // weight_grad, bias_grad, mul_grad and input_grad can be nullable
  // NOTE(review): the kernels also read bias[0] unconditionally, but bias is
  // not null-checked here — confirm callers always pass a defined bias.

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(d, max_block_size));

  // Per-token partial gradients, accumulated in float regardless of the
  // parameter dtype, then reduced across the token dimension below.
  torch::Tensor temp_weight_grad =
      torch::empty({num_tokens, 3}, input.options().dtype(torch::kFloat));
  torch::Tensor temp_bias_grad =
      torch::empty({num_tokens, 1}, output_grad.options().dtype(torch::kFloat));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Wider dtypes need a smaller vector width to keep the same byte width per
  // vectorized access (8 x 2B == 4 x 4B == 16 bytes).
  if (d % 8 == 0 && input.element_size() == 2) {
    LAUNCH_POLY_NORM_BACKWARD(8);
  } else if (d % 4 == 0 && input.element_size() == 4) {
    LAUNCH_POLY_NORM_BACKWARD(4);
  } else {
    LAUNCH_POLY_NORM_BACKWARD(0);
  }

  if (bias_grad.defined()) {
    torch::Tensor acc = torch::empty_like(bias_grad, temp_bias_grad.options());
    at::sum_out(acc, temp_bias_grad, {0});
    bias_grad.copy_(acc);
  }

  if (weight_grad.defined()) {
    torch::Tensor acc =
        torch::empty_like(weight_grad, temp_weight_grad.options());
    at::sum_out(acc, temp_weight_grad, {0});
    weight_grad.copy_(acc);
  }
}
diff --git a/activation/poly_norm.cu b/activation/poly_norm.cu
index a1dcdec4accc3d14602a0fe32701135b24190f74..1977ef7260a66d0a9f66b0cca6aab32eca9d7c67 100644
--- a/activation/poly_norm.cu
+++ b/activation/poly_norm.cu
@@ -7,7 +7,6 @@
#include "assert_utils.h"
#include "atomic_utils.h"
-#include "block_reduce.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -17,6 +16,18 @@ template struct alignas(sizeof(type) * N) type_vec_t {
type data[N];
};
+struct SumOp {
+ __device__ float3 operator()(const float3 &a, const float3 &b) const {
+ return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
+ }
+};
+
+struct SumOp4 {
+ __device__ float4 operator()(const float4 &a, const float4 &b) const {
+ return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+ }
+};
+
template
__global__ std::enable_if_t<(width > 0)>
poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
@@ -39,7 +50,7 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
#pragma unroll
for (int i = 0; i < width; ++i) {
- acc_t x1 = static_cast(x_vec.data[i]);
+ acc_t x1 = x_vec.data[i];
acc_t x2 = x1 * x1;
acc_t x4 = x2 * x2;
acc_t x6 = x4 * x2;
@@ -50,14 +61,16 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
}
}
- using BlockReduce = cub::BlockReduce;
+ using BlockReduce = cub::BlockReduce;
__shared__ typename BlockReduce::TempStorage reduceStore;
- sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
- __syncthreads();
- sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
- __syncthreads();
- sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+ float3 thread_sums = make_float3(sum2, sum4, sum6);
+ float3 block_sums =
+ BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+ sum2 = block_sums.x;
+ sum4 = block_sums.y;
+ sum6 = block_sums.z;
__shared__ acc_t s_bias;
@@ -90,14 +103,12 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
#pragma unroll
for (int i = 0; i < width; ++i) {
- acc_t x1 = static_cast(x_vec.data[i]);
+ acc_t x1 = x_vec.data[i];
acc_t x2 = x1 * x1;
acc_t x3 = x2 * x1;
- acc_t y =
+ y_vec.data[i] =
x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
-
- y_vec.data[i] = static_cast(y);
}
output_vec[vec_offset + idx] = y_vec;
}
@@ -127,14 +138,16 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d]
sum6 += x6;
}
- using BlockReduce = cub::BlockReduce;
+ using BlockReduce = cub::BlockReduce;
__shared__ typename BlockReduce::TempStorage reduceStore;
- sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
- __syncthreads();
- sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
- __syncthreads();
- sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+ float3 thread_sums = make_float3(sum2, sum4, sum6);
+ float3 block_sums =
+ BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+ sum2 = block_sums.x;
+ sum4 = block_sums.y;
+ sum6 = block_sums.z;
__shared__ acc_t s_bias;
@@ -199,7 +212,7 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
#pragma unroll
for (int i = 0; i < width; ++i) {
- acc_t x1 = static_cast(x_vec.data[i]);
+ acc_t x1 = x_vec.data[i];
acc_t x2 = x1 * x1;
acc_t x3 = x2 * x1;
acc_t x4 = x2 * x2;
@@ -209,7 +222,7 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
sum4 += x4;
sum6 += x6;
- acc_t dy = static_cast(dy_vec.data[i]);
+ acc_t dy = dy_vec.data[i];
sum_dx1 += dy * x1;
sum_dx2 += dy * x2;
@@ -217,22 +230,25 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
}
}
- using BlockReduce = cub::BlockReduce;
+ using BlockReduce = cub::BlockReduce;
__shared__ typename BlockReduce::TempStorage reduceStore;
- __syncthreads();
- sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
- __syncthreads();
- sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
- __syncthreads();
- sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+ float3 thread_sums = make_float3(sum2, sum4, sum6);
+ float3 block_sums =
+ BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+ sum2 = block_sums.x;
+ sum4 = block_sums.y;
+ sum6 = block_sums.z;
+
+ float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3);
__syncthreads();
- sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x);
- __syncthreads();
- sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x);
- __syncthreads();
- sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x);
+ float3 block_sum_dxs =
+ BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x);
+
+ sum_dx1 = block_sum_dxs.x;
+ sum_dx2 = block_sum_dxs.y;
+ sum_dx3 = block_sum_dxs.z;
__shared__ acc_t s_mean2;
__shared__ acc_t s_mean4;
@@ -288,16 +304,16 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
#pragma unroll
for (int i = 0; i < width; ++i) {
- acc_t x1 = static_cast(x_vec.data[i]);
+ acc_t x1 = x_vec.data[i];
acc_t x2 = x1 * x1;
acc_t x3 = x2 * x1;
- acc_t dy = static_cast(dy_vec.data[i]);
+ acc_t dy = dy_vec.data[i];
if (input_grad) {
acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
- dx_vec.data[i] = static_cast(dx1 + dx2 + dx3);
+ dx_vec.data[i] = dx1 + dx2 + dx3;
}
sum_dy += dy;
@@ -311,13 +327,17 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
}
}
- sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x);
- __syncthreads();
- sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x);
- __syncthreads();
- sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x);
- __syncthreads();
- sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x);
+ using BlockReduce4 = cub::BlockReduce;
+ __shared__ typename BlockReduce4::TempStorage reduceStore4;
+
+ float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2);
+ float4 block_sum_ds =
+ BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x);
+
+ sum_dy = block_sum_ds.x;
+ sum_dw0 = block_sum_ds.y;
+ sum_dw1 = block_sum_ds.z;
+ sum_dw2 = block_sum_ds.w;
if (threadIdx.x == 0) {
temp_bias_grad[blockIdx.x] = sum_dy;
@@ -364,22 +384,25 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
sum_dx3 += dy * x3;
}
- using BlockReduce = cub::BlockReduce;
+ using BlockReduce = cub::BlockReduce;
__shared__ typename BlockReduce::TempStorage reduceStore;
- __syncthreads();
- sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
- __syncthreads();
- sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
- __syncthreads();
- sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+ float3 thread_sums = make_float3(sum2, sum4, sum6);
+ float3 block_sums =
+ BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+ sum2 = block_sums.x;
+ sum4 = block_sums.y;
+ sum6 = block_sums.z;
+
+ float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3);
__syncthreads();
- sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x);
- __syncthreads();
- sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x);
- __syncthreads();
- sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x);
+ float3 block_sum_dxs =
+ BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x);
+
+ sum_dx1 = block_sum_dxs.x;
+ sum_dx2 = block_sum_dxs.y;
+ sum_dx3 = block_sum_dxs.z;
__shared__ acc_t s_mean2;
__shared__ acc_t s_mean4;
@@ -445,13 +468,17 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
sum_dw2 += dy * (x1 * inv_std1);
}
- sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x);
- __syncthreads();
- sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x);
- __syncthreads();
- sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x);
- __syncthreads();
- sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x);
+ using BlockReduce4 = cub::BlockReduce;
+ __shared__ typename BlockReduce4::TempStorage reduceStore4;
+
+ float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2);
+ float4 block_sum_ds =
+ BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x);
+
+ sum_dy = block_sum_ds.x;
+ sum_dw0 = block_sum_ds.y;
+ sum_dw1 = block_sum_ds.z;
+ sum_dw2 = block_sum_ds.w;
if (threadIdx.x == 0) {
temp_bias_grad[token_idx] = sum_dy;
diff --git a/activation/rms_norm.cu b/activation/rms_norm.cu
index 9a2ffd0ac322acad34e56a1b2d53fde6178b714a..ee91d4dd8b2ae2844df52cab7bc57b054a0a97dd 100644
--- a/activation/rms_norm.cu
+++ b/activation/rms_norm.cu
@@ -7,18 +7,76 @@
#include "assert_utils.h"
#include "atomic_utils.h"
-#include "block_reduce.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace motif {
-template
-__global__ void rms_norm_kernel(scalar_t *__restrict__ out, // [..., d]
- const scalar_t *__restrict__ input, // [..., d]
- const scalar_t *__restrict__ weight, // [d]
- const float eps, const int d) {
+template struct alignas(sizeof(type) * N) type_vec_t {
+ type data[N];
+};
+template
+__global__ std::enable_if_t<(width > 0)>
+rms_norm_kernel(scalar_t *__restrict__ out, // [..., d]
+ const scalar_t *__restrict__ input, // [..., d]
+ const scalar_t *__restrict__ weight, // [d]
+ const float eps, const int d) {
+ using vec_t = type_vec_t;
+
+ const int vec_d = d / width;
+ const int64_t vec_offset = blockIdx.x * vec_d;
+ const vec_t *__restrict__ input_vec = reinterpret_cast(input);
+ acc_t sum_square = 0.0f;
+
+ for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+ vec_t x_vec = input_vec[vec_offset + idx];
+
+#pragma unroll
+ for (int i = 0; i < width; ++i) {
+ acc_t x = x_vec.data[i];
+ sum_square += x * x;
+ }
+ }
+
+ using BlockReduce = cub::BlockReduce;
+ __shared__ typename BlockReduce::TempStorage reduceStore;
+
+ sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
+
+ __shared__ acc_t s_scale;
+
+ if (threadIdx.x == 0) {
+ s_scale = rsqrtf(sum_square / d + eps);
+ }
+ __syncthreads();
+
+ const vec_t *__restrict__ weight_vec =
+ reinterpret_cast(weight);
+ vec_t *__restrict__ output_vec = reinterpret_cast(out);
+
+ for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+ vec_t x_vec = input_vec[vec_offset + idx];
+ vec_t w_vec = weight_vec[idx];
+ vec_t y_vec;
+
+#pragma unroll
+ for (int i = 0; i < width; ++i) {
+ acc_t x = x_vec.data[i];
+ acc_t w = w_vec.data[i];
+
+ y_vec.data[i] = w * x * s_scale;
+ }
+ output_vec[vec_offset + idx] = y_vec;
+ }
+}
+
+template
+__global__ std::enable_if_t<(width == 0)>
+rms_norm_kernel(scalar_t *__restrict__ out, // [..., d]
+ const scalar_t *__restrict__ input, // [..., d]
+ const scalar_t *__restrict__ weight, // [d]
+ const float eps, const int d) {
const int64_t token_idx = blockIdx.x;
const int64_t vec_idx = threadIdx.x;
acc_t sum_square = 0.0f;
@@ -28,20 +86,123 @@ __global__ void rms_norm_kernel(scalar_t *__restrict__ out, // [..., d]
sum_square += x * x;
}
- __shared__ acc_t shared[BLOCK_SIZE];
+ using BlockReduce = cub::BlockReduce;
+ __shared__ typename BlockReduce::TempStorage reduceStore;
+
+ sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x);
+
+ __shared__ acc_t s_scale;
+
+ if (vec_idx == 0) {
+ s_scale = rsqrtf(sum_square / d + eps);
+ }
+ __syncthreads();
- acc_t variance =
- _block_reduce_sum(shared, sum_square, d) / d;
- acc_t scale = rsqrt(variance + eps);
for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
acc_t x = input[token_idx * d + idx];
acc_t w = weight[idx];
- out[token_idx * d + idx] = w * x * scale;
+ out[token_idx * d + idx] = w * x * s_scale;
}
}
+template
+__global__ std::enable_if_t<(width > 0)>
+rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
+ acc_t *__restrict__ temp_weight_grad, // [..., d]
+ const scalar_t *__restrict__ output_grad, // [..., d]
+ const scalar_t *__restrict__ input, // [..., d]
+ const scalar_t *__restrict__ weight, // [d]
+ const float eps, const int d) {
+ using vec_t = type_vec_t;
+ using dw_vec_t = type_vec_t;
+
+ const int64_t token_idx = blockIdx.x;
+ const int64_t vec_idx = threadIdx.x;
+
+ const int vec_d = d / width;
+ const int64_t vec_offset = token_idx * vec_d;
+
+ const vec_t *__restrict__ input_vec = reinterpret_cast(input);
+ const vec_t *__restrict__ output_grad_vec =
+ reinterpret_cast(output_grad);
+ const vec_t *__restrict__ weight_vec =
+ reinterpret_cast(weight);
+
+ acc_t d_sum = 0.0f;
+ acc_t sum_square = 0.0f;
+
+ for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) {
+ vec_t x_vec = input_vec[vec_offset + vidx];
+ vec_t dy_vec = output_grad_vec[vec_offset + vidx];
+ vec_t w_vec = weight_vec[vidx];
-template
-__global__ void
+#pragma unroll
+ for (int i = 0; i < width; ++i) {
+ acc_t x = x_vec.data[i];
+ acc_t dy = dy_vec.data[i];
+ acc_t w = w_vec.data[i];
+ d_sum += dy * x * w;
+ sum_square += x * x;
+ }
+ }
+
+ using BlockReduce = cub::BlockReduce;
+ __shared__ typename BlockReduce::TempStorage reduceStore;
+ struct SumOp {
+ __device__ float2 operator()(const float2 &a, const float2 &b) const {
+ return make_float2(a.x + b.x, a.y + b.y);
+ }
+ };
+ float2 thread_sums = make_float2(d_sum, sum_square);
+ float2 block_sums =
+ BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+ d_sum = block_sums.x;
+ sum_square = block_sums.y;
+
+ __shared__ acc_t s_scale;
+ __shared__ acc_t s_dxx;
+
+ if (threadIdx.x == 0) {
+ acc_t scale = rsqrtf(sum_square / d + eps);
+ s_dxx = d_sum * scale * scale * scale / d;
+ s_scale = scale;
+ }
+ __syncthreads();
+ acc_t scale = s_scale;
+ acc_t dxx = s_dxx;
+ vec_t *__restrict__ input_grad_vec = reinterpret_cast(input_grad);
+ dw_vec_t *__restrict__ temp_weight_grad_vec =
+ reinterpret_cast(temp_weight_grad);
+
+ for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) {
+ vec_t x_vec = input_vec[vec_offset + vidx];
+ vec_t dy_vec = output_grad_vec[vec_offset + vidx];
+ vec_t w_vec = weight_vec[vidx];
+
+ vec_t in_grad_vec;
+ dw_vec_t tw_grad_vec;
+
+#pragma unroll
+ for (int i = 0; i < width; ++i) {
+ acc_t x = x_vec.data[i];
+ acc_t dy = dy_vec.data[i];
+ acc_t w = w_vec.data[i];
+
+ if (input_grad) {
+ in_grad_vec.data[i] = scale * dy * w - dxx * x;
+ }
+ tw_grad_vec.data[i] = dy * x * scale;
+ }
+
+ if (input_grad) {
+ input_grad_vec[vec_offset + vidx] = in_grad_vec;
+ }
+ temp_weight_grad_vec[vec_offset + vidx] = tw_grad_vec;
+ }
+}
+
+template
+__global__ std::enable_if_t<(width == 0)>
rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
acc_t *__restrict__ temp_weight_grad, // [..., d]
const scalar_t *__restrict__ output_grad, // [..., d]
@@ -61,30 +222,55 @@ rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
sum_square += x * x;
}
- __shared__ acc_t shared[BLOCK_SIZE];
+ using BlockReduce = cub::BlockReduce;
+ __shared__ typename BlockReduce::TempStorage reduceStore;
+ struct SumOp {
+ __device__ float2 operator()(const float2 &a, const float2 &b) const {
+ return make_float2(a.x + b.x, a.y + b.y);
+ }
+ };
+ float2 thread_sums = make_float2(d_sum, sum_square);
+ float2 block_sums =
+ BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
- d_sum = _block_reduce_sum(shared, d_sum, d);
- acc_t variance =
- _block_reduce_sum(shared, sum_square, d) / d;
- acc_t scale = rsqrt(variance + eps);
- acc_t scale_cubed = scale * scale * scale;
- acc_t dxx = d_sum * scale_cubed / d;
+ d_sum = block_sums.x;
+ sum_square = block_sums.y;
+
+ __shared__ acc_t s_scale;
+ __shared__ acc_t s_dxx;
+
+ if (threadIdx.x == 0) {
+ acc_t scale = rsqrtf(sum_square / d + eps);
+ s_dxx = d_sum * scale * scale * scale / d;
+ s_scale = scale;
+ }
+ __syncthreads();
+
+ acc_t scale = s_scale;
+ acc_t dxx = s_dxx;
for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
acc_t x = input[token_idx * d + idx];
acc_t dy = output_grad[token_idx * d + idx];
acc_t w = weight[idx];
- input_grad[token_idx * d + idx] = scale * dy * w - dxx * x;
-
- if (temp_weight_grad) {
- temp_weight_grad[token_idx * d + idx] = dy * x * scale;
+ if (input_grad) {
+ input_grad[token_idx * d + idx] = scale * dy * w - dxx * x;
}
+ temp_weight_grad[token_idx * d + idx] = dy * x * scale;
}
}
} // namespace motif
+#define LAUNCH_RMS_NORM(width) \
+ MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { \
+ motif::rms_norm_kernel \
+ <<>>(out.data_ptr(), \
+ input.data_ptr(), \
+ weight.data_ptr(), eps, d); \
+ });
+
void rms_norm(torch::Tensor &out, // [..., d]
const torch::Tensor &input, // [..., d]
const torch::Tensor &weight, // [d]
@@ -93,27 +279,36 @@ void rms_norm(torch::Tensor &out, // [..., d]
AssertTensorNotNull(weight, "weight");
// TODO shape check
- constexpr int BLOCK_SIZE = 256;
-
int d = input.size(-1);
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
- dim3 block(BLOCK_SIZE);
+ const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+ dim3 block(std::min(d, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
- motif::rms_norm_kernel
- <<>>(out.data_ptr(),
- input.data_ptr(),
- weight.data_ptr(), eps, d);
- });
+ if (d % 8 == 0) {
+ LAUNCH_RMS_NORM(8);
+ } else {
+ LAUNCH_RMS_NORM(0);
+ }
}
+#define LAUNCH_RMS_NORM_BWD(width) \
+ MOTIF_DISPATCH_FLOATING_TYPES( \
+ input.scalar_type(), "rms_norm_backward_kernel", [&] { \
+ motif::rms_norm_backward_kernel \
+ <<>>(input_grad.data_ptr(), \
+ temp_weight_grad.data_ptr(), \
+ output_grad.data_ptr(), \
+ input.data_ptr(), \
+ weight.data_ptr(), eps, d); \
+ });
+
void rms_norm_backward(torch::Tensor &input_grad, // [..., d]
- torch::Tensor &weight_grad, // [..., d]
- const torch::Tensor &output_grad, // [d]
- const torch::Tensor &input, // [d]
+ torch::Tensor &weight_grad, // [d]
+ const torch::Tensor &output_grad, // [..., d]
+ const torch::Tensor &input, // [..., d]
const torch::Tensor &weight, // [d]
double eps) {
AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
@@ -122,30 +317,27 @@ void rms_norm_backward(torch::Tensor &input_grad, // [..., d]
// TODO shape check
// weight_grad, input_grad can be nullable
- constexpr int BLOCK_SIZE = 256;
-
int d = input.size(-1);
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
- dim3 block(BLOCK_SIZE);
+ const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+ dim3 block(std::min(d, max_block_size));
torch::Tensor temp_weight_grad =
torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
- MOTIF_DISPATCH_FLOATING_TYPES(
- input.scalar_type(), "rms_norm_backward_kernel", [&] {
- motif::rms_norm_backward_kernel
- <<>>(input_grad.data_ptr(),
- temp_weight_grad.data_ptr(),
- output_grad.data_ptr(),
- input.data_ptr(),
- weight.data_ptr(), eps, d);
- });
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ if (d % 8 == 0) {
+ LAUNCH_RMS_NORM_BWD(8);
+ } else {
+ LAUNCH_RMS_NORM_BWD(0);
+ }
if (weight_grad.defined()) {
- at::sum_out(weight_grad, temp_weight_grad, {0});
+ torch::Tensor acc =
+ torch::empty_like(weight_grad, temp_weight_grad.options());
+ at::sum_out(acc, temp_weight_grad, {0});
+ weight_grad.copy_(acc);
}
}
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e6ad09a2b1b28b342aac1ee2c58f2d12ac6d01ff
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,35 @@
+# Benchmark Runner
+
+This script benchmarks **forward/backward performance** of several operations (`rms`, `add_rms`, `poly`, `mul_poly`).
+Results can be saved as **CSV files** or **plots**.
+
+> **Note**
+> To run the benchmarks, you must select the appropriate Torch version along with the corresponding CUDA/ROCm build from within the `build` directory.
+>
+> **Example:**
+>
+> ```bash
+> export PYTHONPATH=$PYTHONPATH:<path-to-repo>/activation/build/torch27-cxx11-cu128-x86_64-linux
+> ```
+
+## Usage
+
+```bash
+python main.py --case <case> [--plot] [--save-path <path>]
+```
+
+- `--case` (required): one of `rms`, `add_rms`, `poly`, `mul_poly`
+- `--plot`: save plots instead of CSVs
+- `--save-path`: output directory (default: `./configs/`)
+
+## Examples
+
+```bash
+python main.py --case add_rms --save-path ./results/
+python main.py --case poly --plot --save-path ./plots/
+```
+
+## Output
+
+- CSV: `<case>-fwd-perf.csv`, `<case>-bwd-perf.csv`
+- Plots: `plot_<case>-fwd-perf.png`, `plot_<case>-bwd-perf.png`
diff --git a/benchmarks/cases/__init__.py b/benchmarks/cases/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/benchmarks/cases/__init__.py
@@ -0,0 +1 @@
+
diff --git a/benchmarks/cases/add_rms.py b/benchmarks/cases/add_rms.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e055e197c2a9e8540c94b579e88db63824ce424
--- /dev/null
+++ b/benchmarks/cases/add_rms.py
@@ -0,0 +1,55 @@
+import torch
+from common.diff_engine import DiffCase
+
+import activation
+
+
+class FusedAddRMSNorm(torch.nn.Module):
+
+ def __init__(self, d, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(d, dtype=dtype))
+ self.eps = eps
+
+ def forward(self, x, residual):
+ return activation.rms_norm((x + residual), self.weight, self.eps)
+
+
+class AddRMS(DiffCase):
+
+ def build_inputs(self, bs, sl, hidden, dtype, eps):
+ return {
+ "x":
+ torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
+ "residual":
+ torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
+ "weight":
+ torch.ones(hidden, dtype=dtype),
+ "dim":
+ hidden,
+ "eps":
+ eps,
+ "dtype":
+ dtype,
+ }
+
+ def make_naive(self, I):
+ m = FusedAddRMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
+ m.weight = torch.nn.Parameter(I["weight"].detach().clone())
+ return m
+
+ def make_cuda(self, I):
+ m = activation.layers.FusedAddRMSNorm(I["dim"],
+ I["eps"],
+ dtype=I["dtype"])
+ m.weight = torch.nn.Parameter(I["weight"].detach().clone())
+ return m
+
+ def forward(self, obj, I):
+ return obj(I["x"], I["residual"])
+
+ def grad_inputs(self, I):
+ return [I["x"], I["residual"]]
+
+
+CASE = AddRMS()
diff --git a/benchmarks/cases/mul_poly.py b/benchmarks/cases/mul_poly.py
new file mode 100644
index 0000000000000000000000000000000000000000..48597c3cf321915910fdb0f02c607a03fbab4a3e
--- /dev/null
+++ b/benchmarks/cases/mul_poly.py
@@ -0,0 +1,53 @@
+import torch
+from common.diff_engine import DiffCase
+
+import activation
+
+
+class FusedMulPolyNorm(torch.nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(self, x, mul):
+ output = activation.poly_norm(x, self.weight, self.bias, self.eps)
+ return output * mul
+
+
+class MulPoly(DiffCase):
+
+ def build_inputs(self, bs, sl, hidden, dtype, eps):
+ return {
+ "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
+ "mul": torch.randn(bs, sl, hidden, dtype=dtype,
+ requires_grad=True),
+ "weight": torch.ones(3, dtype=dtype),
+ "bias": torch.ones(1, dtype=dtype),
+ "dim": hidden,
+ "eps": eps,
+ "dtype": dtype,
+ }
+
+ def make_naive(self, I):
+ m = FusedMulPolyNorm(I["eps"], dtype=I["dtype"])
+ m.weight = torch.nn.Parameter(I["weight"].detach().clone())
+ m.bias = torch.nn.Parameter(I["bias"].detach().clone())
+ return m
+
+ def make_cuda(self, I):
+ m = activation.layers.FusedMulPolyNorm(I["eps"], dtype=I["dtype"])
+ m.weight = torch.nn.Parameter(I["weight"].detach().clone())
+ m.bias = torch.nn.Parameter(I["bias"].detach().clone())
+ return m
+
+ def forward(self, obj, I):
+ return obj(I["x"], I["mul"])
+
+ def grad_inputs(self, I):
+ return [I["x"], I["mul"]]
+
+
+CASE = MulPoly()
diff --git a/benchmarks/cases/poly.py b/benchmarks/cases/poly.py
new file mode 100644
index 0000000000000000000000000000000000000000..00efede1e6a911e3f278343a9d49c36e619c7684
--- /dev/null
+++ b/benchmarks/cases/poly.py
@@ -0,0 +1,58 @@
+import torch
+from common.diff_engine import DiffCase
+
+import activation
+
+
+class PolyNorm(torch.nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def _norm(self, x):
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ orig_dtype = x.dtype
+ x_float = x.to(torch.float32)
+ output = (self.weight[0] * self._norm(x_float**3) +
+ self.weight[1] * self._norm(x_float**2) +
+ self.weight[2] * self._norm(x_float) + self.bias)
+ return output.to(orig_dtype)
+
+
+class Poly(DiffCase):
+
+ def build_inputs(self, bs, sl, hidden, dtype, eps):
+ return {
+ "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
+ "weight": torch.ones(3, dtype=dtype),
+ "bias": torch.ones(1, dtype=dtype),
+ "dim": hidden,
+ "eps": eps,
+ "dtype": dtype,
+ }
+
+ def make_naive(self, I):
+ m = PolyNorm(I["eps"], dtype=I["dtype"])
+ m.weight = torch.nn.Parameter(I["weight"].detach().clone())
+ m.bias = torch.nn.Parameter(I["bias"].detach().clone())
+ return m
+
+ def make_cuda(self, I):
+ m = activation.layers.PolyNorm(I["eps"], dtype=I["dtype"])
+ m.weight = torch.nn.Parameter(I["weight"].detach().clone())
+ m.bias = torch.nn.Parameter(I["bias"].detach().clone())
+ return m
+
+ def forward(self, obj, I):
+ return obj(I["x"])
+
+ def grad_inputs(self, I):
+ return [I["x"]]
+
+
+CASE = Poly()
diff --git a/benchmarks/cases/rms.py b/benchmarks/cases/rms.py
new file mode 100644
index 0000000000000000000000000000000000000000..15331e8f42f5b4405b4d3a8317f3fa00b56c62ae
--- /dev/null
+++ b/benchmarks/cases/rms.py
@@ -0,0 +1,35 @@
+import torch
+from common.diff_engine import DiffCase
+
+import activation
+
+
+class RMS(DiffCase):
+
+ def build_inputs(self, bs, sl, hidden, dtype, eps):
+ return {
+ "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
+ "weight": torch.ones(hidden, dtype=dtype),
+ "dim": hidden,
+ "eps": eps,
+ "dtype": dtype,
+ }
+
+ def make_naive(self, I):
+ m = torch.nn.RMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
+ m.weight = torch.nn.Parameter(I["weight"].detach().clone())
+ return m
+
+ def make_cuda(self, I):
+ m = activation.layers.RMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
+ m.weight = torch.nn.Parameter(I["weight"].detach().clone())
+ return m
+
+ def forward(self, obj, I):
+ return obj(I["x"])
+
+ def grad_inputs(self, I):
+ return [I["x"]]
+
+
+CASE = RMS()
diff --git a/benchmarks/common/__init__.py b/benchmarks/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/benchmarks/common/__init__.py
@@ -0,0 +1 @@
+
diff --git a/benchmarks/common/bench_framework.py b/benchmarks/common/bench_framework.py
new file mode 100644
index 0000000000000000000000000000000000000000..f24b8d8d6163a85ce347eeb3d3ad06f46f1c67cf
--- /dev/null
+++ b/benchmarks/common/bench_framework.py
@@ -0,0 +1,220 @@
+import collections
+import math
+import re
+from typing import Any, Dict, Sequence
+
+import torch
+import triton
+
+from .diff_engine import DiffCase
+
+
+def make_fwd_key(batch_size, seq_len, dim):
+ return f"forward : ({batch_size}, {seq_len}, {dim})"
+
+
+def make_bwd_key(batch_size, seq_len, dim):
+ return f"backward : ({batch_size}, {seq_len}, {dim})"
+
+
+def parse_config_string(config_str):
+ match = re.match(r"(\w+)\s*:\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)",
+ config_str)
+ if not match:
+ raise ValueError(f"Invalid config string: {config_str}")
+ _, bs, sl, d = match.groups()
+ return int(bs), int(sl), int(d)
+
+
+def make_fwd_benchmark_for_case(
+ *,
+ case: DiffCase,
+ configs: Sequence[tuple[int, int, int]],
+ plot_name: str,
+ ylabel: str = "us",
+ line_vals=("naive", "cuda", "speedup"),
+ line_names: Dict[str, str] | None = None,
+ dtype=torch.bfloat16,
+ eps: float = 1e-6,
+ time_unit_scale: float = 1000,
+):
+ timings_ms = collections.defaultdict(dict)
+ line_vals = list(line_vals)
+ line_names = line_names or {v: v.title() for v in line_vals}
+ x_vals = [list(_) for _ in configs]
+
+ @triton.testing.perf_report(
+ triton.testing.Benchmark(x_names=["dim", "batch_size", "seq_len"],
+ x_vals=x_vals,
+ line_arg="provider",
+ line_vals=line_vals,
+ line_names=[line_names[v] for v in line_vals],
+ ylabel=ylabel,
+ plot_name=plot_name,
+ args={}))
+ def bench(dim, batch_size, seq_len, provider):
+ key = make_fwd_key(dim, batch_size, seq_len)
+ I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
+ if provider == "speedup":
+ return timings_ms["naive"][key] / timings_ms["cuda"][key]
+ obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+ run = lambda: case.forward(obj, I)
+ ms = triton.testing.do_bench(run)
+ timings_ms[provider][key] = ms
+ return time_unit_scale * ms
+
+ return bench
+
+
+def make_fwd_benchmark_plot_for_case(
+ *,
+ case: DiffCase,
+ configs: Sequence[tuple[int, int, int]],
+ plot_name: str,
+ ylabel: str = "Relative Speedup",
+ line_vals=("naive", "cuda"),
+ line_names: Dict[str, str] | None = None,
+ dtype=torch.bfloat16,
+ eps: float = 1e-6,
+):
+ timings_ms = collections.defaultdict(dict)
+ spdup_ratio = list()
+ line_vals = list(line_vals)
+ line_names = line_names or {v: v.title() for v in line_vals}
+ x_vals = [make_fwd_key(*_) for _ in configs]
+ x_vals.append("Geometric Mean")
+
+ @triton.testing.perf_report(
+ triton.testing.Benchmark(x_names=["config"],
+ x_vals=x_vals,
+ line_arg="provider",
+ line_vals=line_vals,
+ line_names=[line_names[v] for v in line_vals],
+ ylabel=ylabel,
+ plot_name=plot_name,
+ args={}))
+ def bench(config, provider):
+ if config == "Geometric Mean":
+ if provider == "cuda":
+ return round(math.prod(spdup_ratio)**(1 / len(spdup_ratio)), 2)
+ else:
+ return 1.00
+ batch_size, seq_len, dim = parse_config_string(config)
+ I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
+ obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+ run = lambda: case.forward(obj, I)
+ ms = triton.testing.do_bench(run)
+ timings_ms[provider][config] = ms
+ if provider == "cuda":
+ ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
+ spdup_ratio.append(ratio)
+ return round(ratio, 2)
+ else:
+ return 1.00
+
+ return bench
+
+
+def make_bwd_benchmark_for_case(
+ *,
+ case: DiffCase,
+ configs: Sequence[tuple[int, int, int]],
+ plot_name: str,
+ ylabel: str = "us",
+ line_vals=("naive", "cuda", "speedup"),
+ line_names: Dict[str, str] | None = None,
+ dtype=torch.bfloat16,
+ eps: float = 1e-6,
+ time_unit_scale: float = 1000,
+):
+ timings_ms = collections.defaultdict(dict)
+ line_vals = list(line_vals)
+ line_names = line_names or {v: v.title() for v in line_vals}
+ x_vals = [list(_) for _ in configs]
+
+ @triton.testing.perf_report(
+ triton.testing.Benchmark(x_names=["dim", "batch_size", "seq_len"],
+ x_vals=x_vals,
+ line_arg="provider",
+ line_vals=line_vals,
+ line_names=[line_names[v] for v in line_vals],
+ ylabel=ylabel,
+ plot_name=plot_name,
+ args={}))
+ def bench(dim, batch_size, seq_len, provider):
+ key = make_bwd_key(dim, batch_size, seq_len)
+ I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
+ if provider == "speedup":
+ return timings_ms["naive"][key] / timings_ms["cuda"][key]
+ obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+ y = case.forward(obj, I)
+ gin = list(case.grad_inputs(I)) + list(obj.parameters())
+ g = torch.randn_like(y)
+ run = lambda: torch.autograd.grad(y,
+ gin,
+ g,
+ retain_graph=True,
+ create_graph=False,
+ allow_unused=False)
+ ms = triton.testing.do_bench(run)
+ timings_ms[provider][key] = ms
+ return time_unit_scale * ms
+
+ return bench
+
+
+def make_bwd_benchmark_plot_for_case(
+ *,
+ case: DiffCase,
+ configs: Sequence[tuple[int, int, int]],
+ plot_name: str,
+ ylabel: str = "Relative Speedup",
+ line_vals=("naive", "cuda"),
+ line_names: Dict[str, str] | None = None,
+ dtype=torch.bfloat16,
+ eps: float = 1e-6,
+):
+ timings_ms = collections.defaultdict(dict)
+ spdup_ratio = list()
+ line_vals = list(line_vals)
+ line_names = line_names or {v: v.title() for v in line_vals}
+ x_vals = [make_bwd_key(*_) for _ in configs]
+ x_vals.append("Geometric Mean")
+
+ @triton.testing.perf_report(
+ triton.testing.Benchmark(x_names=["config"],
+ x_vals=x_vals,
+ line_arg="provider",
+ line_vals=line_vals,
+ line_names=[line_names[v] for v in line_vals],
+ ylabel=ylabel,
+ plot_name=plot_name,
+ args={}))
+ def bench(config, provider):
+ if config == "Geometric Mean":
+ if provider == "cuda":
+ return round(math.prod(spdup_ratio)**(1 / len(spdup_ratio)), 2)
+ else:
+ return 1.00
+ batch_size, seq_len, dim = parse_config_string(config)
+ I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
+ obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+ y = case.forward(obj, I)
+ gin = list(case.grad_inputs(I)) + list(obj.parameters())
+ g = torch.randn_like(y)
+ run = lambda: torch.autograd.grad(y,
+ gin,
+ g,
+ retain_graph=True,
+ create_graph=False,
+ allow_unused=False)
+ ms = triton.testing.do_bench(run)
+ timings_ms[provider][config] = ms
+ if provider == "cuda":
+ ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
+ spdup_ratio.append(ratio)
+ return round(ratio, 2)
+ else:
+ return 1.00
+
+ return bench
diff --git a/benchmarks/common/diff_engine.py b/benchmarks/common/diff_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..276cd3900c34a69740cb35ec3183491055b32751
--- /dev/null
+++ b/benchmarks/common/diff_engine.py
@@ -0,0 +1,85 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Sequence
+
+import torch
+
+
+class DiffCase(ABC):
+
+ @abstractmethod
+ def build_inputs(self, hidden: int, bs: int, sl: int, dtype: torch.dtype,
+ eps: float) -> Dict[str, Any]:
+ ...
+
+ @abstractmethod
+ def make_naive(self, I: Dict[str, Any]) -> Any:
+ ...
+
+ @abstractmethod
+ def make_cuda(self, I: Dict[str, Any]) -> Any:
+ ...
+
+ @abstractmethod
+ def forward(self, obj: Any, I: Dict[str, Any]) -> torch.Tensor:
+ ...
+
+ @abstractmethod
+ def grad_inputs(self, I: Dict[str, Any]) -> Sequence[torch.Tensor]:
+ ...
+
+
+def _clone_payload(d, device):
+ out = {}
+ for k, v in d.items():
+ if isinstance(v, torch.Tensor):
+ t = v.detach().clone().to(device)
+ t.requires_grad_(v.requires_grad)
+ out[k] = t
+ else:
+ out[k] = v
+ return out
+
+
+def _unit_grad_like(y):
+ g = torch.randn_like(y)
+ n = g.norm()
+ return g if n == 0 else g / n
+
+
+def calculate_diff(
+ case: DiffCase,
+ *,
+ batch_size: int,
+ seq_len: int,
+ hidden_size: int,
+ dtype=torch.bfloat16,
+ eps: float = 1e-6,
+ atol: float = 1e-2,
+ rtol: float = 1e-2,
+ device="cuda",
+) -> None:
+ base = case.build_inputs(hidden_size, batch_size, seq_len, dtype, eps)
+ I_n = _clone_payload(base, device)
+ I_c = _clone_payload(base, device)
+ obj_n = case.make_naive(I_n)
+ obj_c = case.make_cuda(I_c)
+ y_n = case.forward(obj_n, I_n)
+ y_c = case.forward(obj_c, I_c)
+ torch.testing.assert_close(y_n, y_c, atol=atol, rtol=rtol)
+ gin_n = list(case.grad_inputs(I_n)) + list(obj_n.parameters())
+ gin_c = list(case.grad_inputs(I_c)) + list(obj_c.parameters())
+ g = _unit_grad_like(y_n).to(device)
+ ng = torch.autograd.grad(y_n,
+ gin_n,
+ g,
+ retain_graph=False,
+ create_graph=False,
+ allow_unused=False)
+ cg = torch.autograd.grad(y_c,
+ gin_c,
+ g,
+ retain_graph=False,
+ create_graph=False,
+ allow_unused=False)
+ torch.testing.assert_close(ng, cg, atol=atol, rtol=rtol)
+ print("✅ forward + backward match")
diff --git a/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png b/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..b596caf613cf58456d23286293dca72d747f377f
Binary files /dev/null and b/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png differ
diff --git a/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png b/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..132e33a45d291c152b44398be9f5d75d2abaa232
Binary files /dev/null and b/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png differ
diff --git a/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png b/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..d8919a93e3d7be46a9d6cd57d832eff48377a637
Binary files /dev/null and b/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png differ
diff --git a/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png b/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..5191efe759849638f7ac9f0f619c8eb56a04975d
Binary files /dev/null and b/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png differ
diff --git a/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png b/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..06d36211ff47c6c1a6feac15d5ea1c708c289ece
Binary files /dev/null and b/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png differ
diff --git a/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png b/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..59f9cbd32aa6532bfcf5284f51a3d56be9d3da9d
Binary files /dev/null and b/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png differ
diff --git a/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png b/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..95815d90490f6c5916fa4835271473ac2565b079
Binary files /dev/null and b/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png differ
diff --git a/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png b/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b86645cfecfbb0e9645527530a9fbd092f68b5f
Binary files /dev/null and b/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png differ
diff --git a/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png b/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..f36820d99706b319eeda8f8a3ed5246383d2524f
Binary files /dev/null and b/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png differ
diff --git a/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png b/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..dacca9cf397419f9af992675388f7299b35e30c0
Binary files /dev/null and b/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png differ
diff --git a/benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png b/benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..3145730e039fd475d3fb990f0acaafc0dd1dd70e
Binary files /dev/null and b/benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png differ
diff --git a/benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png b/benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..ba02f90b3811c6aa78f7d34aa856d8047a5cec34
Binary files /dev/null and b/benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png differ
diff --git a/benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png b/benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..327e7267c2bcff9842032d418855539e4757fefa
Binary files /dev/null and b/benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png differ
diff --git a/benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png b/benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..71544c6ab53b849ab0770442cc7d409044529aa7
Binary files /dev/null and b/benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png differ
diff --git a/benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png b/benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..95bd8ee01a111de532b6fc2c9e957ac48df3efe2
Binary files /dev/null and b/benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png differ
diff --git a/benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png b/benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png
new file mode 100644
index 0000000000000000000000000000000000000000..65e5fa7275861bd5f36ee81e576e27f11e1975b3
Binary files /dev/null and b/benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png differ
diff --git a/benchmarks/run_cases.py b/benchmarks/run_cases.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2e1f746ccf14aed82182c482dc39c05fe03816e
--- /dev/null
+++ b/benchmarks/run_cases.py
@@ -0,0 +1,143 @@
+import argparse
+import glob
+import importlib
+import itertools
+import os
+
+import torch
+from common.bench_framework import (make_bwd_benchmark_for_case,
+ make_bwd_benchmark_plot_for_case,
+ make_fwd_benchmark_for_case,
+ make_fwd_benchmark_plot_for_case)
+from common.diff_engine import DiffCase, calculate_diff
+
+
+def make_title_tag():
+ if torch.cuda.is_available():
+ dev_name = torch.cuda.get_device_name(0)
+ else:
+ dev_name = "CPU"
+
+ torch_ver = torch.__version__
+
+ return f"[{dev_name} | torch {torch_ver}]"
+
+
+def plot_result(r_path):
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ df = pd.read_csv(r_path + ".csv")
+ plt.figure(figsize=(12, 6))
+ ax = df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=plt.gca())
+ ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
+ fontsize=14,
+ fontweight="bold")
+ ax.set_ylabel("Relative Speedup", fontsize=14)
+ ax.set_xlabel("")
+ plt.xticks(rotation=45, fontsize=12, ha="right", rotation_mode="anchor")
+ for container in ax.containers:
+ labels = [f"x{v.get_height():.2f}" for v in container]
+ ax.bar_label(container, labels=labels, label_type="edge", fontsize=10)
+ plt.tight_layout()
+ plt.savefig(r_path + ".png", bbox_inches="tight")
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--case",
+ choices=["rms", "add_rms", "poly", "mul_poly"],
+ required=True)
+ ap.add_argument("--plot", action="store_true")
+ ap.add_argument(
+ "--save-path",
+ type=str,
+ default="./configs/",
+ help="Path to save benchmark results",
+ )
+ args = ap.parse_args()
+
+ torch.set_default_device("cuda")
+ mod = importlib.import_module(f"cases.{args.case}")
+ case: DiffCase = mod.CASE
+
+ calculate_diff(
+ case,
+ batch_size=2,
+ seq_len=128,
+ hidden_size=4096,
+ )
+
+ save_dir = os.path.join(args.save_path, args.case)
+ if args.plot:
+ batch_size_range = [1]
+ seq_length_range = [4096, 8192, 16384]
+ dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
+ configs = list(
+ itertools.product(batch_size_range, seq_length_range, dim))
+ plot_name = f"plot_{args.case}-fwd-perf"
+ bench = make_fwd_benchmark_plot_for_case(
+ case=case,
+ configs=configs,
+ plot_name=plot_name,
+ line_names={
+ "naive": "Naive",
+ "cuda": "Cuda",
+ },
+ )
+ bench.run(print_data=True, save_path=save_dir)
+ plot_result(os.path.join(save_dir, plot_name))
+
+ plot_name = f"plot_{args.case}-bwd-perf"
+ bench = make_bwd_benchmark_plot_for_case(
+ case=case,
+ configs=configs,
+ plot_name=plot_name,
+ line_names={
+ "naive": "Naive",
+ "cuda": "Cuda",
+ },
+ )
+ bench.run(print_data=True, save_path=save_dir)
+ plot_result(os.path.join(save_dir, plot_name))
+ for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
+ os.path.join(save_dir, "*.csv")):
+ os.remove(f)
+ else:
+ batch_size_range = [2**i for i in range(0, 4, 1)]
+ seq_length_range = [2**i for i in range(10, 14, 1)]
+ dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
+ configs = list(
+ itertools.product(dim, batch_size_range, seq_length_range))
+
+ bench = make_fwd_benchmark_for_case(
+ case=case,
+ configs=configs,
+ plot_name=f"{args.case}-fwd-perf",
+ line_names={
+ "naive": "Naive",
+ "cuda": "Cuda",
+ "speedup": "SpeedUp"
+ },
+ )
+
+ bench.run(print_data=True, save_path=save_dir)
+
+ bench = make_bwd_benchmark_for_case(
+ case=case,
+ configs=configs,
+ plot_name=f"{args.case}-bwd-perf",
+ line_names={
+ "naive": "Naive",
+ "cuda": "Cuda",
+ "speedup": "SpeedUp"
+ },
+ )
+
+ bench.run(print_data=True, save_path=save_dir)
+ for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
+ os.path.join(save_dir, "*.png")):
+ os.remove(f)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/build.toml b/build.toml
index ce1953079d04f1265c38cdbc6080bbecc7705d1c..a61a661991131939392cb8ab771f57b320b68867 100644
--- a/build.toml
+++ b/build.toml
@@ -13,9 +13,10 @@ backend = "rocm"
rocm-archs = [ "gfx90a", "gfx942" ]
src = [
"activation/poly_norm.cu",
+ "activation/fused_mul_poly_norm.cu",
"activation/rms_norm.cu",
+ "activation/fused_add_rms_norm.cu",
"activation/cuda_compat.h",
- "activation/block_reduce.h",
"activation/dispatch_utils.h",
"activation/assert_utils.h",
"activation/atomic_utils.h",
@@ -26,9 +27,10 @@ depends = [ "torch" ]
backend = "cuda"
src = [
"activation/poly_norm.cu",
+ "activation/fused_mul_poly_norm.cu",
"activation/rms_norm.cu",
+ "activation/fused_add_rms_norm.cu",
"activation/cuda_compat.h",
- "activation/block_reduce.h",
"activation/dispatch_utils.h",
"activation/assert_utils.h",
"activation/atomic_utils.h",
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py
+++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float = 1e-6,
+) -> None:
+ return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+) -> None:
+ return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..1b3674d54c044dddf2d037c1d3bac522bc19440c
--- /dev/null
+++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d21a85bf21aa74f1281541e658acfd4f4326d902efe3578b059eccf054443284
+size 8089696
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py
+++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py
+++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py
+++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py
+++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py
+++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float = 1e-6,
+) -> None:
+ return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+) -> None:
+ return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..df3c3ae7785a3c30c36d900923c1dd7a349448db
--- /dev/null
+++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d4955271509451b946495da75f69a0f978e7258b8303fe3c077e585c0d3e6a
+size 8272456
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py
+++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py
+++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py
+++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py
+++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py
+++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float = 1e-6,
+) -> None:
+ return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+) -> None:
+ return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..0de3488964fc7207148b7b9b62cc4db838e64c7b
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bf0d2ab5ff5520704e0b0c959b61d0043d360cfd4335950e69677873a87e436
+size 12792112
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py
+++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py
+++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py
+++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py
+++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py
+++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float = 1e-6,
+) -> None:
+ return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+) -> None:
+ return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..57361102c13046a6a1aab2f7125193ece35b21da
--- /dev/null
+++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:640322a8fac8fd9d8e9f195a3034c4ee0f81ee1acf897fd7c482a84ce47a1bec
+size 4160688
diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py
+++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py
+++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py
+++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py
+++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py
+++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..c703b3b19594e8b20ee5b4dc7692fbdad8079365
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1768d8d5072ac06d937cb5332988c6b3bfaa191f72d1369a22d2c577e9a3bca2
+size 8215280
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py
+++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py
+++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py
+++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py
+++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py
+++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..ecdc467a674247fe3898453418ce88a9983d08c5
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37a572bd877980ab8c0331ca5682191cb5a2b1f05bc69ea493a9e24f7728ba3f
+size 12730840
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py
+++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py
+++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py
+++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py
+++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py
+++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..d6c8a74ea050b78cf9dcd4c43ac618094b0ca303
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f15919c4cac697cde550af16256e338472400e50df751e93622350c7f626bc8
+size 12726208
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py
+++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py
+++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py
+++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py
+++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+    x: torch.Tensor,
+    mul: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..670a8291fdc208c690447600ee77449e1fac9929
--- /dev/null
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e72d4bb4459a5da96ca5eda1d305237a361140f0e25360e3d20326a22f1b6d47
+size 4165584
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+    x: torch.Tensor,
+    mul: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so
new file mode 100644
index 0000000000000000000000000000000000000000..c8f702b9ecfdc1c01dcdd2880d088458c4f11c2d
--- /dev/null
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3325c2748cf7a070383068995078f93f440cc95fbed491d00bd414cdd851376
+size 4171472
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py
index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644
--- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_f517c97_dirty
-ops = torch.ops._activation_f517c97_dirty
+from . import _activation_20250907180255
+ops = torch.ops._activation_20250907180255
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_f517c97_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_20250907180255::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # This function only needs one gradient
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/tests/kernels/allclose_default.py b/tests/allclose_default.py
similarity index 100%
rename from tests/kernels/allclose_default.py
rename to tests/allclose_default.py
diff --git a/tests/conftest.py b/tests/conftest.py
deleted file mode 100644
index eda6569be97917dc009f15b3ca4f11136708aa0c..0000000000000000000000000000000000000000
--- a/tests/conftest.py
+++ /dev/null
@@ -1,144 +0,0 @@
-import logging
-
-import numpy as np
-import plotly.graph_objects as go
-import pytest
-
-from .kernels.test_poly_norm_perf import PERF_RESULTS, PerfResult
-
-logger = logging.getLogger(__name__)
-DO_PLOT = False
-
-
-def plot(perf_results: list[PerfResult]):
- x_labels = [f"{r.type}, {r.shape}, {r.dtype}" for r in perf_results]
- kernel_speedup = [r.speedup for r in perf_results]
- torch_speedup = [1 for _ in perf_results]
-
- geo_mean = float(np.exp(np.mean(np.log(kernel_speedup))))
- x_labels.append("Geometric Mean")
- kernel_speedup.append(geo_mean)
- torch_speedup.append(1.0)
-
- fig = go.Figure()
-
- bar_width = 0.2
- fig.add_trace(
- go.Bar(
- x=x_labels,
- y=kernel_speedup,
- name="Activation",
- marker_color="rgb(100, 100, 100)",
- text=[f"x{v:.2f}" for v in kernel_speedup],
- textfont=dict(size=14),
- textposition="outside",
- # width=[bar_width] * len(x_labels),
- ))
-
- fig.add_trace(
- go.Bar(
- x=x_labels,
- y=torch_speedup,
- name="Torch",
- marker_color="rgb(30, 30, 30)",
- text=[f"x{v:.2f}" for v in torch_speedup],
- textfont=dict(size=14),
- textposition="outside",
- # width=[bar_width] * len(x_labels),
- ))
-
- fig.update_layout(
- title=dict(
- text=
- "Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)",
- font=dict(size=24),
- ),
- legend=dict(
- x=0.01,
- y=0.99,
- xanchor="left",
- yanchor="top",
- bgcolor="rgba(0,0,0,0)",
- bordercolor="black",
- borderwidth=1,
- ),
- font=dict(size=16),
- yaxis_title="Speedup (torch / activation)",
- barmode="group",
- bargroupgap=0,
- bargap=0.2,
- xaxis_tickangle=-45,
- template="plotly_white",
- yaxis_type="log",
- shapes=[
- dict(
- type="rect",
- xref="x",
- yref="paper", # y축 전체 범위 (0~1)
- x0=-0.5,
- x1=len(x_labels) - 0.5,
- y0=0,
- y1=1,
- line=dict(
- color="black",
- width=1.5,
- ),
- fillcolor="rgba(0,0,0,0)", # 투명 배경
- layer="above", # bar 아래에 그리기
- )
- ],
- )
-
- output_file = "perf_result.html"
- fig.write_html(output_file)
- logger.info(f"Plotting performance results to {output_file}")
-
-
-def pytest_addoption(parser):
- parser.addoption("--run-perf",
- action="store_true",
- default=False,
- help="Run perf tests")
- parser.addoption("--do-plot",
- action="store_true",
- default=False,
- help="Plot performance results")
-
-
-@pytest.fixture
-def do_plot(request):
- return request.config.getoption("--do-plot")
-
-
-def pytest_configure(config):
- global DO_PLOT
- DO_PLOT = config.getoption("--do-plot")
- run_perf = config.getoption("--run-perf")
-
- if DO_PLOT and not run_perf:
- raise ValueError(
- "Cannot plot performance results without running performance tests. "
- "Please use --run-perf option.")
-
- config.addinivalue_line("markers",
- "perf: mark test as performance-related")
-
-
-def pytest_collection_modifyitems(config, items):
- run_perf = config.getoption("--run-perf")
-
- skip_perf = pytest.mark.skip(reason="need --run-perf option to run")
- skip_normal = pytest.mark.skip(
- reason="normal tests skipped when --run-perf is used")
- for item in items:
- if "perf" in item.keywords and not run_perf:
- item.add_marker(skip_perf)
- elif "perf" not in item.keywords and run_perf:
- item.add_marker(skip_normal)
-
-
-def pytest_sessionfinish(session, exitstatus) -> None:
- if DO_PLOT:
- plot(PERF_RESULTS)
- else:
- logger.info(PERF_RESULTS)
diff --git a/tests/kernels/__init__.py b/tests/kernels/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/tests/kernels/test_poly_norm_perf.py b/tests/kernels/test_poly_norm_perf.py
deleted file mode 100644
index 88c35f34e8acb6cb835eedb7c1d6657e16b28da3..0000000000000000000000000000000000000000
--- a/tests/kernels/test_poly_norm_perf.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import random
-from dataclasses import dataclass
-
-import pytest
-import torch
-
-import activation
-
-from .test_poly_norm import poly_norm
-from .utils import assert_close
-
-CASES = [
- ((1, 2048, 8192), torch.bfloat16),
- ((1, 2048, 16384), torch.bfloat16),
- ((1, 16384, 8192), torch.bfloat16),
- ((1, 16384, 16384), torch.bfloat16),
-]
-NUM_REP = 100
-
-
-@dataclass
-class PerfResult:
- type: str # forward or backward
- shape: tuple
- dtype: torch.dtype
- kernel_time_ms: float
- torch_time_ms: float
-
- @property
- def speedup(self) -> float:
- return self.torch_time_ms / self.kernel_time_ms
-
-
-PERF_RESULTS: list[PerfResult] = []
-
-
-@pytest.mark.parametrize("cases", CASES)
-@pytest.mark.perf
-def test_poly_norm(
- cases: tuple,
- do_plot: bool,
-) -> None:
- random.seed(12345)
- torch.manual_seed(12345)
-
- torch.set_default_device("cuda")
-
- shape, dtype = cases
- x = torch.randn(shape, dtype=dtype, requires_grad=True)
- weight = torch.randn(3, dtype=dtype, requires_grad=True)
- bias = torch.randn(1, dtype=dtype, requires_grad=True)
- eps = 1e-05
-
- x.retain_grad()
- weight.retain_grad()
- bias.retain_grad()
- # To separate gradient computation, clone the inputs
-
- x_ref = x.detach().clone().requires_grad_(True)
- weight_ref = weight.detach().clone().requires_grad_(True)
- bias_ref = bias.detach().clone().requires_grad_(True)
-
- torch_fn = poly_norm
- layer = activation.layers.PolyNorm(eps)
- layer.weight = torch.nn.Parameter(weight)
- layer.bias = torch.nn.Parameter(bias)
-
- # Check correctness
- mod_out = layer(x)
- ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
- assert_close(mod_out, ref_out)
-
- out_grad = torch.rand_like(ref_out)
- out_grad = out_grad / out_grad.norm()
-
- ref_out.backward(out_grad, retain_graph=True)
- mod_out.backward(out_grad, retain_graph=True)
-
- assert_close(x.grad, x_ref.grad)
- assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
- assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
-
- def time_cuda(fn):
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
-
- for _ in range(5):
- fn()
- start.record()
- for _ in range(NUM_REP):
- fn()
- end.record()
- torch.cuda.synchronize()
- return start.elapsed_time(end) / NUM_REP
-
- kernel_time_ms = time_cuda(lambda: layer(x))
- torch_fn_time = time_cuda(
- lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))
-
- PERF_RESULTS.append(
- PerfResult(
- type="forward",
- shape=shape,
- dtype=dtype,
- kernel_time_ms=kernel_time_ms,
- torch_time_ms=torch_fn_time,
- ))
-
- kernel_time_ms = time_cuda(
- lambda: mod_out.backward(out_grad, retain_graph=True))
- torch_fn_time = time_cuda(
- lambda: ref_out.backward(out_grad, retain_graph=True))
-
- PERF_RESULTS.append(
- PerfResult(
- type="backward",
- shape=shape,
- dtype=dtype,
- kernel_time_ms=kernel_time_ms,
- torch_time_ms=torch_fn_time,
- ))
diff --git a/tests/perf.png b/tests/perf.png
deleted file mode 100644
index 398b315aff1f0af712c8b0996bacfe78992c49ed..0000000000000000000000000000000000000000
--- a/tests/perf.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:12f88f9ac4511cb37f38a34e3572e4347bd0c857144a4aaf64bd5981d6b50877
-size 165982
diff --git a/tests/perf_result.html b/tests/perf_result.html
deleted file mode 100644
index ccd99a78e8c7212caa63e79e7035477b94576b8c..0000000000000000000000000000000000000000
--- a/tests/perf_result.html
+++ /dev/null
@@ -1,3885 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/tests/test_fused_add_rms_norm.py b/tests/test_fused_add_rms_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5486471c131e4543917cf7f399553e6dc71a7a72
--- /dev/null
+++ b/tests/test_fused_add_rms_norm.py
@@ -0,0 +1,101 @@
+import random
+
+import pytest
+import torch
+
+import activation
+
+from .utils import assert_close, opcheck
+
+DTYPES = [torch.float, torch.bfloat16, torch.half]
+NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing
+D = [1, 7, 512, 13824] # Arbitrary values for testing
+SEEDS = [0]
+CUDA_DEVICES = [
+ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+
+def add_rms_norm_all_naive(x: torch.Tensor, residual: torch.Tensor,
+ weight: torch.Tensor, eps: float) -> torch.Tensor:
+ return torch.nn.functional.rms_norm((x + residual), weight.shape, weight,
+ eps)
+
+
+#use rms_norm kernel
+def add_rms_norm_partial_naive(x: torch.Tensor, residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float) -> torch.Tensor:
+ return activation.rms_norm((x + residual), weight, eps)
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_fused_add_rms_norm(
+ num_tokens: int,
+ d: int,
+ dtype: torch.dtype,
+ seed: int,
+ device: str,
+) -> None:
+ random.seed(seed)
+ torch.manual_seed(seed)
+ torch.set_default_device(device)
+
+ x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
+ residual = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
+ weight = torch.randn(d, dtype=dtype, requires_grad=True)
+ eps = 1e-05
+
+ x.retain_grad()
+ residual.retain_grad()
+ weight.retain_grad()
+ # To separate gradient computation, clone the inputs
+
+ x_ref = x.detach().clone().requires_grad_(True)
+ residual_ref = residual.detach().clone().requires_grad_(True)
+ weight_ref = weight.detach().clone().requires_grad_(True)
+
+ x_ref2 = x.detach().clone().requires_grad_(True)
+ residual_ref2 = residual.detach().clone().requires_grad_(True)
+ weight_ref2 = weight.detach().clone().requires_grad_(True)
+
+ torch_fn = add_rms_norm_all_naive
+ torch_fn2 = add_rms_norm_partial_naive
+
+ op = activation.ops.fused_add_rms_norm
+ fn = activation.fused_add_rms_norm
+
+ layer = activation.layers.FusedAddRMSNorm(d, eps)
+ layer.weight = torch.nn.Parameter(weight)
+
+ out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+ add_out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+ opcheck(op, (out, add_out, x, residual, weight, eps))
+
+ out = fn(x, residual, weight, eps)
+ mod_out = layer(x, residual)
+ ref_out = torch_fn(x_ref, residual_ref, weight_ref, eps)
+ ref_out2 = torch_fn2(x_ref2, residual_ref2, weight_ref2, eps)
+
+ assert_close(out, ref_out)
+ assert_close(out, ref_out2)
+ assert_close(mod_out, out, atol=0.0, rtol=0.0)
+
+ # test backward pass
+ out_grad = torch.randn_like(out)
+ out_grad = out_grad / out_grad.norm()
+
+ ref_out.backward(out_grad)
+ ref_out2.backward(out_grad)
+ mod_out.backward(out_grad)
+
+ assert_close(x.grad, x_ref.grad)
+ assert_close(x.grad, x_ref2.grad)
+ assert_close(residual.grad, residual_ref.grad)
+ assert_close(residual.grad, residual_ref2.grad)
+ assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
+ assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05)
diff --git a/tests/test_fused_mul_poly_norm.py b/tests/test_fused_mul_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f77c3d32bdbe0b87d51ffacc5280fec497cd81b
--- /dev/null
+++ b/tests/test_fused_mul_poly_norm.py
@@ -0,0 +1,120 @@
+import random
+
+import pytest
+import torch
+
+import activation
+
+from .utils import assert_close, opcheck
+
+DTYPES = [torch.float, torch.bfloat16, torch.half]
+NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing
+D = [1, 7, 512, 13824] # Arbitrary values for testing
+SEEDS = [0]
+CUDA_DEVICES = [
+ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+
+def norm(x, eps: float) -> torch.Tensor:
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
+ eps: float) -> torch.Tensor:
+ x = x.float()
+ return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
+ weight[2] * norm(x, eps) + bias).to(weight.dtype)
+
+
+def mul_poly_norm_all_naive(x: torch.Tensor, mul: torch.Tensor,
+ weight: torch.Tensor, bias: torch.Tensor,
+ eps: float) -> torch.Tensor:
+ return poly_norm(x, weight, bias, eps) * mul
+
+
+#use poly_norm kernel
+def mul_poly_norm_partial_naive(x: torch.Tensor, mul: torch.Tensor,
+ weight: torch.Tensor, bias: torch.Tensor,
+ eps: float) -> torch.Tensor:
+ return activation.poly_norm(x, weight, bias, eps) * mul
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_fused_mul_poly_norm(
+ num_tokens: int,
+ d: int,
+ dtype: torch.dtype,
+ seed: int,
+ device: str,
+) -> None:
+ random.seed(seed)
+ torch.manual_seed(seed)
+ torch.set_default_device(device)
+
+ x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
+ mul = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
+ weight = torch.randn(3, dtype=dtype, requires_grad=True)
+ bias = torch.randn(1, dtype=dtype, requires_grad=True)
+ eps = 1e-05
+
+ x.retain_grad()
+ mul.retain_grad()
+ weight.retain_grad()
+ bias.retain_grad()
+ # To separate gradient computation, clone the inputs
+
+ x_ref = x.detach().clone().requires_grad_(True)
+ mul_ref = mul.detach().clone().requires_grad_(True)
+ weight_ref = weight.detach().clone().requires_grad_(True)
+ bias_ref = bias.detach().clone().requires_grad_(True)
+
+ x_ref2 = x.detach().clone().requires_grad_(True)
+ mul_ref2 = mul.detach().clone().requires_grad_(True)
+ weight_ref2 = weight.detach().clone().requires_grad_(True)
+ bias_ref2 = bias.detach().clone().requires_grad_(True)
+
+ torch_fn = mul_poly_norm_all_naive
+ torch_fn2 = mul_poly_norm_partial_naive
+
+ op = activation.ops.fused_mul_poly_norm
+ fn = activation.fused_mul_poly_norm
+
+ layer = activation.layers.FusedMulPolyNorm(eps)
+ layer.weight = torch.nn.Parameter(weight)
+ layer.bias = torch.nn.Parameter(bias)
+
+ out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+ opcheck(op, (out, x, mul, weight, bias, eps))
+
+ out = fn(x, mul, weight, bias, eps)
+ mod_out = layer(x, mul)
+ ref_out = torch_fn(x_ref, mul_ref, weight_ref, bias_ref, eps)
+ ref_out2 = torch_fn2(x_ref2, mul_ref2, weight_ref2, bias_ref2, eps)
+
+ # Mul amplifies small numeric differences between naive poly_norm and the kernel.
+ # When validating against all_naive, use a looser rtol/atol.
+ assert_close(out, ref_out, atol=0.01, rtol=0.01)
+ assert_close(out, ref_out2)
+ assert_close(mod_out, out, atol=0.0, rtol=0.0)
+
+ # test backward pass
+ out_grad = torch.randn_like(out)
+ out_grad = out_grad / out_grad.norm()
+
+ ref_out.backward(out_grad)
+ ref_out2.backward(out_grad)
+ mod_out.backward(out_grad)
+
+ assert_close(x.grad, x_ref.grad)
+ assert_close(x.grad, x_ref2.grad)
+ assert_close(mul.grad, mul_ref.grad)
+ assert_close(mul.grad, mul_ref2.grad)
+ assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
+ assert_close(layer.bias.grad, bias_ref2.grad, rtol=0.05)
+ assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
+ assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05)
diff --git a/tests/kernels/test_poly_norm.py b/tests/test_poly_norm.py
similarity index 92%
rename from tests/kernels/test_poly_norm.py
rename to tests/test_poly_norm.py
index dab3d26a722e7a1db6917e3745df165ed2f08ffb..3fb3129482251d4994a3ce2540cfa925ba49b700 100644
--- a/tests/kernels/test_poly_norm.py
+++ b/tests/test_poly_norm.py
@@ -8,10 +8,8 @@ import activation
from .utils import assert_close, opcheck
DTYPES = [torch.float, torch.bfloat16, torch.half]
-# NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
-# D = [512, 13824] # Arbitrary values for testing
-NUM_TOKENS = [7, 13] # Arbitrary values for testing
-D = [513] # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing
+D = [1, 7, 512, 13824] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
diff --git a/tests/kernels/test_rms_norm.py b/tests/test_rms_norm.py
similarity index 90%
rename from tests/kernels/test_rms_norm.py
rename to tests/test_rms_norm.py
index 20fef91a1b9a82c380f04c273e4b169636f98cca..f812ec9cf22fe0fcb23c040f8ffbf8624e344ef2 100644
--- a/tests/kernels/test_rms_norm.py
+++ b/tests/test_rms_norm.py
@@ -8,10 +8,8 @@ import activation
from .utils import assert_close, opcheck
DTYPES = [torch.float, torch.bfloat16, torch.half]
-# NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
-# D = [512, 13824] # Arbitrary values for testing
-NUM_TOKENS = [7, 13] # Arbitrary values for testing
-D = [513] # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing
+D = [1, 7, 512, 13824] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
diff --git a/tests/kernels/utils.py b/tests/utils.py
similarity index 100%
rename from tests/kernels/utils.py
rename to tests/utils.py
diff --git a/torch-ext/activation/__init__.py b/torch-ext/activation/__init__.py
index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644
--- a/torch-ext/activation/__init__.py
+++ b/torch-ext/activation/__init__.py
@@ -2,8 +2,8 @@ import torch
from . import layers
from ._ops import ops
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
def poly_norm(
@@ -15,6 +15,16 @@ def poly_norm(
return PolyNormFunction.apply(x, weight, bias, eps)
+def fused_mul_poly_norm(
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps)
+
+
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor,
@@ -23,8 +33,20 @@ def rms_norm(
return RMSNormFunction.apply(x, weight, eps)
+def fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+
+
__all__ = [
"poly_norm",
+ "fused_mul_poly_norm",
+ "rms_norm",
+ "fused_add_rms_norm",
"layers",
"ops",
]
diff --git a/torch-ext/activation/layers.py b/torch-ext/activation/layers.py
index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644
--- a/torch-ext/activation/layers.py
+++ b/torch-ext/activation/layers.py
@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from torch.nn import init
-from .poly_norm import PolyNormFunction
-from .rms_norm import RMSNormFunction
+from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
+from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
class PolyNorm(nn.Module):
@@ -28,6 +28,30 @@ class PolyNorm(nn.Module):
init.zeros_(self.bias)
+class FusedMulPolyNorm(nn.Module):
+
+ def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+ self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mul: torch.Tensor,
+ ):
+ return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias,
+ self.eps)
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.constant_(self.weight, 1.0 / 3.0)
+ init.zeros_(self.bias)
+
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
@@ -46,3 +70,25 @@ class RMSNorm(nn.Module):
Resets parameters based on their initialization used in __init__.
"""
init.ones_(self.weight)
+
+
+class FusedAddRMSNorm(nn.Module):
+
+ def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+ self.eps = eps
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ ):
+ return FusedAddRMSNormFunction.apply(x, residual, self.weight,
+ self.eps)[0]
+
+ def reset_parameters(self) -> None:
+ """
+ Resets parameters based on their initialization used in __init__.
+ """
+ init.ones_(self.weight)
diff --git a/torch-ext/activation/poly_norm.py b/torch-ext/activation/poly_norm.py
index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644
--- a/torch-ext/activation/poly_norm.py
+++ b/torch-ext/activation/poly_norm.py
@@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function):
input, weight, eps)
return input_grad, weight_grad, bias_grad, None
+
+
+class FusedMulPolyNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, mul, weight, bias, eps):
+ output = torch.empty_like(input)
+ ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps)
+ return output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, output):
+ input, mul, weight, bias, eps = inputs
+ ctx.save_for_backward(input, mul, weight, bias)
+ ctx.eps = eps
+
+ # This function has only a single output, so it gets only one gradient
+ @staticmethod
+ def backward(ctx, output_grad):
+ input, mul, weight, bias = ctx.saved_tensors
+ eps = ctx.eps
+
+ input_grad = torch.empty_like(
+ input) if ctx.needs_input_grad[0] else None
+ mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+ bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+ if ctx.needs_input_grad[3] else None)
+
+ ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad,
+ bias_grad, output_grad, input, mul,
+ weight, bias, eps)
+
+ return input_grad, mul_grad, weight_grad, bias_grad, None
diff --git a/torch-ext/activation/rms_norm.py b/torch-ext/activation/rms_norm.py
index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644
--- a/torch-ext/activation/rms_norm.py
+++ b/torch-ext/activation/rms_norm.py
@@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function):
weight, eps)
return input_grad, weight_grad, None
+
+
+# Inherit from Function
+class FusedAddRMSNormFunction(torch.autograd.Function):
+ # Note that forward, setup_context, and backward are @staticmethods
+ @staticmethod
+ def forward(input, residual, weight, eps):
+ output = torch.empty_like(input)
+ add_output = torch.empty_like(input)
+ ops.fused_add_rms_norm(output, add_output, input, residual, weight,
+ eps)
+ return output, add_output
+
+ @staticmethod
+ # inputs is a Tuple of all of the inputs passed to forward.
+ # output is the output of the forward().
+ def setup_context(ctx, inputs, outputs):
+ _, _, weight, eps = inputs
+ _, add_output = outputs
+ ctx.mark_non_differentiable(add_output)
+ ctx.set_materialize_grads(False)
+ ctx.save_for_backward(weight, add_output)
+ ctx.eps = eps
+
+ # add_output is marked non-differentiable, so only output_grad's gradient is used
+ @staticmethod
+ def backward(ctx, output_grad, _):
+ weight, add_output = ctx.saved_tensors
+ eps = ctx.eps
+
+ if output_grad is None:
+ output_grad = torch.zeros_like(add_output)
+
+ need_in = ctx.needs_input_grad[0]
+ need_res = ctx.needs_input_grad[1]
+
+ grad = torch.empty_like(output_grad) if need_in or need_res else None
+
+ weight_grad = torch.empty_like(
+ weight) if ctx.needs_input_grad[2] else None
+
+ ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+ weight, eps)
+ input_grad = grad if need_in else None
+ residual_grad = grad if need_res else None
+
+ return input_grad, residual_grad, weight_grad, None
diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp
index 74fb10647ac4d9c3704d2cbf41f235d2c6495e56..5c71a5ba1e025e276b32d6878defb0ab1f28ffd1 100644
--- a/torch-ext/torch_binding.cpp
+++ b/torch-ext/torch_binding.cpp
@@ -1,25 +1,50 @@
+#include "torch_binding.h"
+
#include <torch/library.h>
#include "registration.h"
-#include "torch_binding.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
- // Activation ops
+ // poly_norm
ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, "
"float eps) -> ()");
+ ops.impl("poly_norm", torch::kCUDA, &poly_norm);
+
ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! "
"bias_grad, Tensor output_grad, Tensor input, Tensor weight, float "
"eps) -> ()");
- ops.impl("poly_norm", torch::kCUDA, &poly_norm);
ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
- // Activation ops
+ // rms_norm
ops.def(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
+ ops.impl("rms_norm", torch::kCUDA, &rms_norm);
+
ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor "
"output_grad, Tensor input, Tensor weight, float eps) -> ()");
- ops.impl("rms_norm", torch::kCUDA, &rms_norm);
ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
+
+ // fused_mul_poly_norm
+ ops.def("fused_mul_poly_norm(Tensor! out, Tensor input, Tensor mul, Tensor "
+ "weight, Tensor bias, "
+ "float eps) -> ()");
+ ops.impl("fused_mul_poly_norm", torch::kCUDA, &fused_mul_poly_norm);
+
+ ops.def("fused_mul_poly_norm_backward(Tensor! input_grad, Tensor! mul_grad, "
+ "Tensor! weight_grad, Tensor! "
+ "bias_grad, Tensor output_grad, Tensor input, Tensor mul, Tensor "
+ "weight, Tensor "
+ "bias, float eps) -> ()");
+ ops.impl("fused_mul_poly_norm_backward", torch::kCUDA,
+ &fused_mul_poly_norm_backward);
+
+ // fused_add_rms_norm
+ // fused_add_rms_norm_backward uses rms_norm_backward_kernel
+ ops.def(
+ "fused_add_rms_norm(Tensor! out, Tensor! add_out, Tensor input, Tensor "
+ "residual, Tensor "
+ "weight, float eps) -> ()");
+ ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h
index a64ae48061cf3b34405793564ebd1c9d76ffb006..b3629b1afbe07d99f6a60bd8c1f846fdc46ff8b0 100644
--- a/torch-ext/torch_binding.h
+++ b/torch-ext/torch_binding.h
@@ -17,3 +17,19 @@ void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
const torch::Tensor &output_grad,
const torch::Tensor &input, const torch::Tensor &weight,
double eps);
+
+void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input,
+ const torch::Tensor &mul, const torch::Tensor &weight,
+ const torch::Tensor &bias, double eps);
+void fused_mul_poly_norm_backward(
+ torch::Tensor &input_grad, torch::Tensor &mul_grad,
+ torch::Tensor &weight_grad, torch::Tensor &bias_grad,
+ const torch::Tensor &output_grad, const torch::Tensor &input,
+ const torch::Tensor &mul, const torch::Tensor &weight,
+ const torch::Tensor &bias, double eps);
+
+// fused_add_rms_norm_backward uses rms_norm_backward_kernel
+void fused_add_rms_norm(torch::Tensor &out, torch::Tensor &add_out,
+ const torch::Tensor &input,
+ const torch::Tensor &residual,
+ const torch::Tensor &weight, double eps);