diff --git a/README.md b/README.md index 0cad06b5de366188f3e251ef2ff54dc7e8bb2f87..219f3b3c19af202414aa8dbc0b6a885a05ffa7c7 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,37 @@ Activation is a python package that contains custom CUDA-based activation kernel - Currently implemented - [PolyNorm](https://arxiv.org/html/2411.03884v1) - [RMSNorm](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html) + - **FusedAddRMSNorm** + + A fused operator that combines **residual addition** (`x + residual`) with **RMSNorm** in a single kernel. + - Instead of: + + ```python + y = x + residual + out = rms_norm(y, weight, eps) + ``` + + - Fused as: + + ```python + out = fused_add_rms_norm(x, residual, weight, eps) + ``` + + - **FusedMulPolyNorm** + + A fused operator that combines **PolyNorm** with an **element-wise multiplication** by a Tensor. + - Instead of: + + ```python + y = poly_norm(x, weight, bias, eps) + out = y * a + ``` + + - Fused as: + + ```python + out = fused_mul_poly_norm(x, a, weight, bias, eps) + ``` ## Usage @@ -28,18 +59,158 @@ print(poly_norm(x)) ``` ## Performance +- Test cases are from the Motif LLM +- The results can be reproduced using the provided benchmarking tools. +- For details on how to use the benchmarking tools, please refer to the [benchmarks README](./benchmarks/README.md). +- The benchmark results may show fluctuations, especially in the backward pass and when the dimension size is small. + +### RMSNorm + +#### H100 Results + +
+Forward Performance + +![RMSNorm Forward Performance](./benchmarks/plots/h100/rms/plot_rms-fwd-perf.png) + +
+ +
+Backward Performance + +![RMSNorm Backward Performance](./benchmarks/plots/h100/rms/plot_rms-bwd-perf.png) + +
+ +#### MI250 Results + +
+Forward Performance + +![RMSNorm Forward Performance](./benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png) + +
+ +
+Backward Performance + +![RMSNorm Backward Performance](./benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png) + +
+ +--- + +### FusedAddRMSNorm + +> [!NOTE] +> For fusion case performance, the **non-fused baseline** was implemented with our **custom kernels**. + +#### H100 Results + +
+Forward Performance + +![FusedAddRMSNorm Forward Performance](./benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png) + +
+ +
+Backward Performance + +![FusedAddRMSNorm Backward Performance](./benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png) + +
+ +#### MI250 Results + +
+Forward Performance + +![FusedAddRMSNorm Forward Performance](./benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png) + +
+ +
+Backward Performance + +![FusedAddRMSNorm Backward Performance](./benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png) + +
+ +--- ### PolyNorm -- Test cases are from the Motif LLM -- You can reproduce the results with: +#### H100 Results -```bash -cd tests -pytest --run-perf --do-plot -``` +
+Forward Performance + +![PolyNorm Forward Performance](./benchmarks/plots/h100/poly/plot_poly-fwd-perf.png) + +
+ +
+Backward Performance + +![PolyNorm Backward Performance](./benchmarks/plots/h100/poly/plot_poly-bwd-perf.png) + +
+ +#### MI250 Results + +
+Forward Performance + +![PolyNorm Forward Performance](./benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png) + +
+ +
+Backward Performance + +![PolyNorm Backward Performance](./benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png) + +
+ +--- + +### FusedMulPolyNorm + +> [!NOTE] +> For fusion case performance, the **non-fused baseline** was implemented with our **custom kernels**. + +#### H100 Results + +
+Forward Performance + +![FusedMulPolyNorm Forward Performance](./benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png) + +
+ +
+Backward Performance + +![FusedMulPolyNorm Backward Performance](./benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png) + +
+ +#### MI250 Results + +
+Forward Performance + +![FusedMulPolyNorm Forward Performance](./benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png) + +
+ +
+Backward Performance + +![FusedMulPolyNorm Backward Performance](./benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png) -![PolyNorm Performance](./tests/perf.png) +
## Pre-commit Hooks diff --git a/activation/block_reduce.h b/activation/block_reduce.h deleted file mode 100644 index 61c56e3b4f71646ddfeb5f879bc8169273c285f8..0000000000000000000000000000000000000000 --- a/activation/block_reduce.h +++ /dev/null @@ -1,21 +0,0 @@ -namespace motif { - -template -__device__ acc_t _block_reduce_sum(acc_t *shared, const float val, - const int d) { - // TODO: Optimize with warp-level primitives - __syncthreads(); - - shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f; - __syncthreads(); - for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) { - if (threadIdx.x < stride) { - shared[threadIdx.x] += shared[threadIdx.x + stride]; - } - __syncthreads(); - } - - return shared[0]; -} - -} // namespace motif diff --git a/activation/fused_add_rms_norm.cu b/activation/fused_add_rms_norm.cu new file mode 100644 index 0000000000000000000000000000000000000000..9e73bd8f61fbd1ad59824398cf6b43e121071ea8 --- /dev/null +++ b/activation/fused_add_rms_norm.cu @@ -0,0 +1,157 @@ +#include +#include +#include +#include + +#include + +#include "assert_utils.h" +#include "atomic_utils.h" +#include "cuda_compat.h" +#include "dispatch_utils.h" + +namespace motif { + +template struct alignas(sizeof(type) * N) type_vec_t { + type data[N]; +}; + +template +__global__ std::enable_if_t<(width > 0)> +fused_add_rms_norm_kernel(scalar_t *__restrict__ out, // [..., d] + scalar_t *__restrict__ add_out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ residual, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { + using vec_t = type_vec_t; + + const int vec_d = d / width; + const int64_t vec_offset = blockIdx.x * vec_d; + const vec_t *__restrict__ input_vec = reinterpret_cast(input); + const vec_t *__restrict__ residual_vec = + reinterpret_cast(residual); + vec_t *__restrict__ add_out_vec = reinterpret_cast(add_out); + acc_t sum_square = 0.0f; + + for (int64_t idx = threadIdx.x; idx < 
vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + vec_t res_vec = residual_vec[vec_offset + idx]; + vec_t add_vec; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x = x_vec.data[i] + res_vec.data[i]; + sum_square += x * x; + add_vec.data[i] = x; + } + add_out_vec[vec_offset + idx] = add_vec; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x); + + __shared__ acc_t s_scale; + + if (threadIdx.x == 0) { + s_scale = rsqrtf(sum_square / d + eps); + } + __syncthreads(); + + const vec_t *__restrict__ weight_vec = + reinterpret_cast(weight); + vec_t *__restrict__ output_vec = reinterpret_cast(out); + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = add_out_vec[vec_offset + idx]; + vec_t w_vec = weight_vec[idx]; + vec_t y_vec; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x = x_vec.data[i]; + acc_t w = w_vec.data[i]; + + y_vec.data[i] = w * x * s_scale; + } + output_vec[vec_offset + idx] = y_vec; + } +} + +template +__global__ std::enable_if_t<(width == 0)> +fused_add_rms_norm_kernel(scalar_t *__restrict__ out, // [..., d] + scalar_t *__restrict__ add_out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ residual, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { + const int64_t token_idx = blockIdx.x; + const int64_t vec_idx = threadIdx.x; + acc_t sum_square = 0.0f; + + for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) { + acc_t x = input[token_idx * d + idx] + residual[token_idx * d + idx]; + sum_square += x * x; + add_out[token_idx * d + idx] = x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x); + + __shared__ acc_t s_scale; + + if (vec_idx == 
0) { + s_scale = rsqrtf(sum_square / d + eps); + } + __syncthreads(); + + for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) { + acc_t x = add_out[token_idx * d + idx]; + acc_t w = weight[idx]; + out[token_idx * d + idx] = w * x * s_scale; + } +} + +} // namespace motif + +#define LAUNCH_RMS_NORM(width) \ + MOTIF_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + motif::fused_add_rms_norm_kernel \ + <<>>( \ + out.data_ptr(), add_out.data_ptr(), \ + input.data_ptr(), residual.data_ptr(), \ + weight.data_ptr(), eps, d); \ + }); + +void fused_add_rms_norm(torch::Tensor &out, // [..., d] + torch::Tensor &add_out, // [..., d] + const torch::Tensor &input, // [..., d] + const torch::Tensor &residual, // [..., d] + const torch::Tensor &weight, // [d] + double eps) { + AssertTensorShapeEqual(input, residual, "input", "residual"); + AssertTensorShapeEqual(input, out, "input", "out"); + AssertTensorShapeEqual(input, add_out, "input", "result"); + AssertTensorNotNull(weight, "weight"); + // TODO shape check + + int d = input.size(-1); + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + const int max_block_size = (num_tokens < 256) ? 
1024 : 256; + dim3 block(std::min(d, max_block_size)); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (d % 8 == 0) { + LAUNCH_RMS_NORM(8); + } else { + LAUNCH_RMS_NORM(0); + } +} diff --git a/activation/fused_mul_poly_norm.cu b/activation/fused_mul_poly_norm.cu new file mode 100644 index 0000000000000000000000000000000000000000..ce35c1dd67be0fd4c41658b55fc433bd87d1038a --- /dev/null +++ b/activation/fused_mul_poly_norm.cu @@ -0,0 +1,642 @@ +#include +#include +#include +#include + +#include + +#include "assert_utils.h" +#include "atomic_utils.h" +#include "cuda_compat.h" +#include "dispatch_utils.h" + +namespace motif { + +template struct alignas(sizeof(type) * N) type_vec_t { + type data[N]; +}; + +struct SumOp { + __device__ float3 operator()(const float3 &a, const float3 &b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } +}; + +struct SumOp4 { + __device__ float4 operator()(const float4 &a, const float4 &b) const { + return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); + } +}; + +template +__global__ std::enable_if_t<(width > 0)> +fused_mul_poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ mul, // [..., d] + const scalar_t *__restrict__ weight, // [3] + const scalar_t *__restrict__ bias, // [1] + const float eps, const int d) { + using vec_t = type_vec_t; + + const int vec_d = d / width; + const int64_t vec_offset = blockIdx.x * vec_d; + const vec_t *__restrict__ input_vec = reinterpret_cast(input); + + acc_t sum2 = 0.0f; + acc_t sum4 = 0.0f; + acc_t sum6 = 0.0f; + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x1 = x_vec.data[i]; + acc_t x2 = x1 * x1; + acc_t x4 = x2 * x2; + acc_t x6 = x4 * x2; + + sum2 += x2; + sum4 += x4; + 
sum6 += x6; + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + float3 thread_sums = make_float3(sum2, sum4, sum6); + float3 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + sum2 = block_sums.x; + sum4 = block_sums.y; + sum6 = block_sums.z; + + __shared__ acc_t s_bias; + + __shared__ acc_t s_w2_inv_std1; + __shared__ acc_t s_w1_inv_std2; + __shared__ acc_t s_w0_inv_std3; + + if (threadIdx.x == 0) { + acc_t w0 = weight[0]; + acc_t w1 = weight[1]; + acc_t w2 = weight[2]; + s_bias = bias[0]; + + s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2; + s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1; + s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0; + } + __syncthreads(); + + acc_t w2_inv_std1 = s_w2_inv_std1; + acc_t w1_inv_std2 = s_w1_inv_std2; + acc_t w0_inv_std3 = s_w0_inv_std3; + acc_t bias_reg = s_bias; + + vec_t *__restrict__ output_vec = reinterpret_cast(out); + const vec_t *__restrict__ mul_vec = reinterpret_cast(mul); + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + vec_t m_vec = mul_vec[vec_offset + idx]; + vec_t y_vec; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x1 = x_vec.data[i]; + scalar_t m = m_vec.data[i]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + scalar_t poly_norm_result = + x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg; + y_vec.data[i] = poly_norm_result * m; + } + output_vec[vec_offset + idx] = y_vec; + } +} + +template +__global__ std::enable_if_t<(width == 0)> +fused_mul_poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ mul, // [..., d] + const scalar_t *__restrict__ weight, // [3] + const scalar_t *__restrict__ bias, // [1] + const float eps, const int d) { + const int64_t token_idx = blockIdx.x; + + acc_t sum2 = 0.0f; + acc_t sum4 = 0.0f; + acc_t sum6 = 0.0f; + + for 
(int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + acc_t x1 = input[token_idx * d + idx]; + acc_t x2 = x1 * x1; + acc_t x4 = x2 * x2; + acc_t x6 = x4 * x2; + + sum2 += x2; + sum4 += x4; + sum6 += x6; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + float3 thread_sums = make_float3(sum2, sum4, sum6); + float3 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + sum2 = block_sums.x; + sum4 = block_sums.y; + sum6 = block_sums.z; + + __shared__ acc_t s_bias; + + __shared__ acc_t s_w2_inv_std1; + __shared__ acc_t s_w1_inv_std2; + __shared__ acc_t s_w0_inv_std3; + + if (threadIdx.x == 0) { + acc_t w0 = weight[0]; + acc_t w1 = weight[1]; + acc_t w2 = weight[2]; + s_bias = bias[0]; + + s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2; + s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1; + s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0; + } + __syncthreads(); + + acc_t w2_inv_std1 = s_w2_inv_std1; + acc_t w1_inv_std2 = s_w1_inv_std2; + acc_t w0_inv_std3 = s_w0_inv_std3; + acc_t bias_reg = s_bias; + + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + acc_t x1 = input[token_idx * d + idx]; + scalar_t m = mul[token_idx * d + idx]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + scalar_t poly_norm_result = + x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg; + out[token_idx * d + idx] = poly_norm_result * m; + } +} + +template +__global__ std::enable_if_t<(width > 0)> fused_mul_poly_norm_backward_kernel( + scalar_t *__restrict__ input_grad, // [..., d] + scalar_t *__restrict__ mul_grad, // [..., d] + acc_t *__restrict__ temp_weight_grad, // [..., 3] + acc_t *__restrict__ temp_bias_grad, // [..., 1] + const scalar_t *__restrict__ output_grad, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ mul, // [..., d] + const scalar_t *__restrict__ weight, // [3] + const scalar_t *__restrict__ bias, // [1] + const float eps, const 
int d) { + using vec_t = type_vec_t; + + const int vec_d = d / width; + const int64_t vec_offset = blockIdx.x * vec_d; + const vec_t *__restrict__ input_vec = reinterpret_cast(input); + const vec_t *__restrict__ mul_vec = reinterpret_cast(mul); + const vec_t *__restrict__ output_grad_vec = + reinterpret_cast(output_grad); + + acc_t sum2 = 0.0f; + acc_t sum4 = 0.0f; + acc_t sum6 = 0.0f; + + acc_t sum_dx1 = 0.0f; + acc_t sum_dx2 = 0.0f; + acc_t sum_dx3 = 0.0f; + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + vec_t dy_fused_vec = output_grad_vec[vec_offset + idx]; + vec_t m_vec = mul_vec[vec_offset + idx]; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x1 = x_vec.data[i]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + acc_t x4 = x2 * x2; + acc_t x6 = x3 * x3; + + sum2 += x2; + sum4 += x4; + sum6 += x6; + + acc_t dy = dy_fused_vec.data[i] * m_vec.data[i]; + + sum_dx1 += dy * x1; + sum_dx2 += dy * x2; + sum_dx3 += dy * x3; + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + float3 thread_sums = make_float3(sum2, sum4, sum6); + float3 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + sum2 = block_sums.x; + sum4 = block_sums.y; + sum6 = block_sums.z; + + float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3); + __syncthreads(); + float3 block_sum_dxs = + BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x); + + sum_dx1 = block_sum_dxs.x; + sum_dx2 = block_sum_dxs.y; + sum_dx3 = block_sum_dxs.z; + + __shared__ acc_t s_mean2; + __shared__ acc_t s_mean4; + __shared__ acc_t s_mean6; + __shared__ acc_t s_sdx1; + __shared__ acc_t s_sdx2; + __shared__ acc_t s_sdx3; + + const acc_t inv_d = acc_t(1) / d; + + if (threadIdx.x == 0) { + s_mean2 = sum2 * inv_d + eps; + s_mean4 = sum4 * inv_d + eps; + s_mean6 = sum6 * inv_d + eps; + + s_sdx1 = sum_dx1 * inv_d; + s_sdx2 = sum_dx2 * inv_d; + 
s_sdx3 = sum_dx3 * inv_d; + } + __syncthreads(); + + acc_t w0 = weight[0]; + acc_t w1 = weight[1]; + acc_t w2 = weight[2]; + acc_t bias_reg = bias[0]; + + acc_t mean2 = s_mean2; + acc_t mean4 = s_mean4; + acc_t mean6 = s_mean6; + acc_t sdx1 = s_sdx1; + acc_t sdx2 = s_sdx2; + acc_t sdx3 = s_sdx3; + + acc_t inv_std1 = rsqrtf(mean2); + acc_t inv_std2 = rsqrtf(mean4); + acc_t inv_std3 = rsqrtf(mean6); + + acc_t w2_inv_std1 = inv_std1 * w2; + acc_t w1_inv_std2 = inv_std2 * w1; + acc_t w0_inv_std3 = inv_std3 * w0; + + // inv_std / mean == powf(mean, -1.5) + acc_t c1 = w2_inv_std1 / mean2; + acc_t c2 = acc_t(2) * w1_inv_std2 / mean4; + acc_t c3 = acc_t(3) * w0_inv_std3 / mean6; + + acc_t sum_dy = 0; + acc_t sum_dw0 = 0; + acc_t sum_dw1 = 0; + acc_t sum_dw2 = 0; + + vec_t *__restrict__ input_grad_vec = reinterpret_cast(input_grad); + vec_t *__restrict__ mul_grad_vec = reinterpret_cast(mul_grad); + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + vec_t dy_fused_vec = output_grad_vec[vec_offset + idx]; + vec_t m_vec = mul_vec[vec_offset + idx]; + vec_t dx_vec; + vec_t dm_vec; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x1 = x_vec.data[i]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + acc_t dy = dy_fused_vec.data[i] * m_vec.data[i]; + + // For register optimization, the order of the following logic matters. + // The input_grad related logic must be placed at the very end. 
+ sum_dy += dy; + sum_dw0 += dy * (x3 * inv_std3); + sum_dw1 += dy * (x2 * inv_std2); + sum_dw2 += dy * (x1 * inv_std1); + + if (mul_grad) { + scalar_t poly_norm_result = + x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg; + dm_vec.data[i] = poly_norm_result * dy_fused_vec.data[i]; + } + + if (input_grad) { + acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3); + acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2); + acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1); + dx_vec.data[i] = dx1 + dx2 + dx3; + } + } + + if (input_grad) { + input_grad_vec[vec_offset + idx] = dx_vec; + } + if (mul_grad) { + mul_grad_vec[vec_offset + idx] = dm_vec; + } + } + + using BlockReduce4 = cub::BlockReduce; + __shared__ typename BlockReduce4::TempStorage reduceStore4; + + float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2); + float4 block_sum_ds = + BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x); + + sum_dy = block_sum_ds.x; + sum_dw0 = block_sum_ds.y; + sum_dw1 = block_sum_ds.z; + sum_dw2 = block_sum_ds.w; + + if (threadIdx.x == 0) { + temp_bias_grad[blockIdx.x] = sum_dy; + temp_weight_grad[blockIdx.x * 3 + 0] = sum_dw0; + temp_weight_grad[blockIdx.x * 3 + 1] = sum_dw1; + temp_weight_grad[blockIdx.x * 3 + 2] = sum_dw2; + } +} + +template +__global__ std::enable_if_t<(width == 0)> fused_mul_poly_norm_backward_kernel( + scalar_t *__restrict__ input_grad, // [..., d] + scalar_t *__restrict__ mul_grad, // [..., d] + acc_t *__restrict__ temp_weight_grad, // [..., 3] + acc_t *__restrict__ temp_bias_grad, // [..., 1] + const scalar_t *__restrict__ output_grad, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ mul, // [..., d] + const scalar_t *__restrict__ weight, // [3] + const scalar_t *__restrict__ bias, // [1] + const float eps, const int d) { + const int64_t token_idx = blockIdx.x; + + acc_t sum2 = 0.0f; + acc_t sum4 = 0.0f; + acc_t sum6 = 0.0f; + + acc_t sum_dx1 = 0.0f; + acc_t sum_dx2 = 0.0f; + 
acc_t sum_dx3 = 0.0f; + + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + acc_t dy = output_grad[token_idx * d + idx] * mul[token_idx * d + idx]; + + acc_t x1 = input[token_idx * d + idx]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + acc_t x4 = x2 * x2; + acc_t x6 = x3 * x3; + + sum2 += x2; + sum4 += x4; + sum6 += x6; + + sum_dx1 += dy * x1; + sum_dx2 += dy * x2; + sum_dx3 += dy * x3; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + float3 thread_sums = make_float3(sum2, sum4, sum6); + float3 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + sum2 = block_sums.x; + sum4 = block_sums.y; + sum6 = block_sums.z; + + float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3); + __syncthreads(); + float3 block_sum_dxs = + BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x); + + sum_dx1 = block_sum_dxs.x; + sum_dx2 = block_sum_dxs.y; + sum_dx3 = block_sum_dxs.z; + + __shared__ acc_t s_mean2; + __shared__ acc_t s_mean4; + __shared__ acc_t s_mean6; + __shared__ acc_t s_sdx1; + __shared__ acc_t s_sdx2; + __shared__ acc_t s_sdx3; + + const acc_t inv_d = acc_t(1) / d; + + if (threadIdx.x == 0) { + s_mean2 = sum2 * inv_d + eps; + s_mean4 = sum4 * inv_d + eps; + s_mean6 = sum6 * inv_d + eps; + + s_sdx1 = sum_dx1 * inv_d; + s_sdx2 = sum_dx2 * inv_d; + s_sdx3 = sum_dx3 * inv_d; + } + __syncthreads(); + + acc_t w0 = weight[0]; + acc_t w1 = weight[1]; + acc_t w2 = weight[2]; + acc_t bias_reg = bias[0]; + + acc_t mean2 = s_mean2; + acc_t mean4 = s_mean4; + acc_t mean6 = s_mean6; + acc_t sdx1 = s_sdx1; + acc_t sdx2 = s_sdx2; + acc_t sdx3 = s_sdx3; + + acc_t inv_std1 = rsqrtf(mean2); + acc_t inv_std2 = rsqrtf(mean4); + acc_t inv_std3 = rsqrtf(mean6); + + acc_t w2_inv_std1 = inv_std1 * w2; + acc_t w1_inv_std2 = inv_std2 * w1; + acc_t w0_inv_std3 = inv_std3 * w0; + + // inv_std / mean == powf(mean, -1.5) + acc_t c1 = w2_inv_std1 / mean2; + acc_t c2 = acc_t(2) * 
w1_inv_std2 / mean4; + acc_t c3 = acc_t(3) * w0_inv_std3 / mean6; + + acc_t sum_dy = 0; + acc_t sum_dw0 = 0; + acc_t sum_dw1 = 0; + acc_t sum_dw2 = 0; + + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + scalar_t dy_fused = output_grad[token_idx * d + idx]; + acc_t dy = dy_fused * mul[token_idx * d + idx]; + acc_t x1 = input[token_idx * d + idx]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + + if (input_grad) { + acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3); + acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2); + acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1); + input_grad[token_idx * d + idx] = dx1 + dx2 + dx3; + } + + if (mul_grad) { + scalar_t poly_norm_result = + x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg; + mul_grad[token_idx * d + idx] = poly_norm_result * dy_fused; + } + + sum_dy += dy; + sum_dw0 += dy * (x3 * inv_std3); + sum_dw1 += dy * (x2 * inv_std2); + sum_dw2 += dy * (x1 * inv_std1); + } + + using BlockReduce4 = cub::BlockReduce; + __shared__ typename BlockReduce4::TempStorage reduceStore4; + + float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2); + float4 block_sum_ds = + BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x); + + sum_dy = block_sum_ds.x; + sum_dw0 = block_sum_ds.y; + sum_dw1 = block_sum_ds.z; + sum_dw2 = block_sum_ds.w; + + if (threadIdx.x == 0) { + temp_bias_grad[token_idx] = sum_dy; + temp_weight_grad[token_idx * 3 + 0] = sum_dw0; + temp_weight_grad[token_idx * 3 + 1] = sum_dw1; + temp_weight_grad[token_idx * 3 + 2] = sum_dw2; + } +} + +} // namespace motif + +#define LAUNCH_FUSED_MUL_POLY_NORM(width) \ + MOTIF_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_mul_poly_norm_kernel", [&] { \ + motif::fused_mul_poly_norm_kernel \ + <<>>( \ + out.data_ptr(), input.data_ptr(), \ + mul.data_ptr(), weight.data_ptr(), \ + bias.data_ptr(), eps, d); \ + }); + +void fused_mul_poly_norm(torch::Tensor &out, // [..., d] + const torch::Tensor &input, // [..., d] + const 
torch::Tensor &mul, // [..., d] + const torch::Tensor &weight, // [3] + const torch::Tensor &bias, // [1] + double eps) { + AssertTensorShapeEqual(input, out, "input", "out"); + AssertTensorShapeEqual(input, mul, "input", "mul"); + AssertTensorNotNull(weight, "weight"); + AssertTensorNotNull(bias, "bias"); + // TODO shape check + + int d = input.size(-1); + int64_t num_tokens = input.numel() / d; + dim3 grid(num_tokens); + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(d, max_block_size)); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (d % 8 == 0) { + LAUNCH_FUSED_MUL_POLY_NORM(8); + } else { + LAUNCH_FUSED_MUL_POLY_NORM(0); + } +} + +#define LAUNCH_POLY_NORM_BACKWARD(width) \ + MOTIF_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_mul_poly_norm_backward_kernel", [&] { \ + motif::fused_mul_poly_norm_backward_kernel \ + <<>>( \ + input_grad.data_ptr(), \ + mul_grad.data_ptr(), \ + temp_weight_grad.data_ptr(), \ + temp_bias_grad.data_ptr(), \ + output_grad.data_ptr(), input.data_ptr(), \ + mul.data_ptr(), weight.data_ptr(), \ + bias.data_ptr(), eps, d); \ + }); + +void fused_mul_poly_norm_backward(torch::Tensor &input_grad, // [..., d] + torch::Tensor &mul_grad, // [..., d] + torch::Tensor &weight_grad, // [3] + torch::Tensor &bias_grad, // [1] + const torch::Tensor &output_grad, // [..., d] + const torch::Tensor &input, // [..., d] + const torch::Tensor &mul, // [..., d] + const torch::Tensor &weight, // [3] + const torch::Tensor &bias, // [1] + double eps) { + AssertTensorShapeEqual(input, input_grad, "input", "input_grad"); + AssertTensorShapeEqual(input, output_grad, "input", "output_grad"); + AssertTensorShapeEqual(input, mul_grad, "input", "mul_grad"); + AssertTensorShapeEqual(input, mul, "input", "mul"); + AssertTensorNotNull(weight, "weight"); + // TODO shape check + // weight_grad, bias_grad, mul_grad and input_grad can be 
nullable + + int d = input.size(-1); + int64_t num_tokens = input.numel() / d; + dim3 grid(num_tokens); + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(d, max_block_size)); + + torch::Tensor temp_weight_grad = + torch::empty({num_tokens, 3}, input.options().dtype(torch::kFloat)); + torch::Tensor temp_bias_grad = + torch::empty({num_tokens, 1}, output_grad.options().dtype(torch::kFloat)); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (d % 8 == 0 && input.element_size() == 2) { + LAUNCH_POLY_NORM_BACKWARD(8); + } else if (d % 4 == 0 && input.element_size() == 4) { + LAUNCH_POLY_NORM_BACKWARD(4); + } else { + LAUNCH_POLY_NORM_BACKWARD(0); + } + + if (bias_grad.defined()) { + torch::Tensor acc = torch::empty_like(bias_grad, temp_bias_grad.options()); + at::sum_out(acc, temp_bias_grad, {0}); + bias_grad.copy_(acc); + } + + if (weight_grad.defined()) { + torch::Tensor acc = + torch::empty_like(weight_grad, temp_weight_grad.options()); + at::sum_out(acc, temp_weight_grad, {0}); + weight_grad.copy_(acc); + } +} diff --git a/activation/poly_norm.cu b/activation/poly_norm.cu index a1dcdec4accc3d14602a0fe32701135b24190f74..1977ef7260a66d0a9f66b0cca6aab32eca9d7c67 100644 --- a/activation/poly_norm.cu +++ b/activation/poly_norm.cu @@ -7,7 +7,6 @@ #include "assert_utils.h" #include "atomic_utils.h" -#include "block_reduce.h" #include "cuda_compat.h" #include "dispatch_utils.h" @@ -17,6 +16,18 @@ template struct alignas(sizeof(type) * N) type_vec_t { type data[N]; }; +struct SumOp { + __device__ float3 operator()(const float3 &a, const float3 &b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } +}; + +struct SumOp4 { + __device__ float4 operator()(const float4 &a, const float4 &b) const { + return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); + } +}; + template __global__ std::enable_if_t<(width > 0)> 
poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] @@ -39,7 +50,7 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] #pragma unroll for (int i = 0; i < width; ++i) { - acc_t x1 = static_cast(x_vec.data[i]); + acc_t x1 = x_vec.data[i]; acc_t x2 = x1 * x1; acc_t x4 = x2 * x2; acc_t x6 = x4 * x2; @@ -50,14 +61,16 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] } } - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x); - __syncthreads(); - sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x); - __syncthreads(); - sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x); + float3 thread_sums = make_float3(sum2, sum4, sum6); + float3 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + sum2 = block_sums.x; + sum4 = block_sums.y; + sum6 = block_sums.z; __shared__ acc_t s_bias; @@ -90,14 +103,12 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] #pragma unroll for (int i = 0; i < width; ++i) { - acc_t x1 = static_cast(x_vec.data[i]); + acc_t x1 = x_vec.data[i]; acc_t x2 = x1 * x1; acc_t x3 = x2 * x1; - acc_t y = + y_vec.data[i] = x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg; - - y_vec.data[i] = static_cast(y); } output_vec[vec_offset + idx] = y_vec; } @@ -127,14 +138,16 @@ poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] sum6 += x6; } - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x); - __syncthreads(); - sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x); - __syncthreads(); - sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x); + float3 thread_sums = make_float3(sum2, sum4, sum6); + float3 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + sum2 = block_sums.x; 
+ sum4 = block_sums.y; + sum6 = block_sums.z; __shared__ acc_t s_bias; @@ -199,7 +212,7 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] #pragma unroll for (int i = 0; i < width; ++i) { - acc_t x1 = static_cast(x_vec.data[i]); + acc_t x1 = x_vec.data[i]; acc_t x2 = x1 * x1; acc_t x3 = x2 * x1; acc_t x4 = x2 * x2; @@ -209,7 +222,7 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] sum4 += x4; sum6 += x6; - acc_t dy = static_cast(dy_vec.data[i]); + acc_t dy = dy_vec.data[i]; sum_dx1 += dy * x1; sum_dx2 += dy * x2; @@ -217,22 +230,25 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] } } - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - __syncthreads(); - sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x); - __syncthreads(); - sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x); - __syncthreads(); - sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x); + float3 thread_sums = make_float3(sum2, sum4, sum6); + float3 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + sum2 = block_sums.x; + sum4 = block_sums.y; + sum6 = block_sums.z; + + float3 thread_dxs = make_float3(sum_dx1, sum_dx2, sum_dx3); __syncthreads(); - sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x); - __syncthreads(); - sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x); - __syncthreads(); - sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x); + float3 block_sum_dxs = + BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x); + + sum_dx1 = block_sum_dxs.x; + sum_dx2 = block_sum_dxs.y; + sum_dx3 = block_sum_dxs.z; __shared__ acc_t s_mean2; __shared__ acc_t s_mean4; @@ -288,16 +304,16 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] #pragma unroll for (int i = 0; i < width; ++i) { - acc_t x1 = static_cast(x_vec.data[i]); + acc_t x1 = 
x_vec.data[i]; acc_t x2 = x1 * x1; acc_t x3 = x2 * x1; - acc_t dy = static_cast(dy_vec.data[i]); + acc_t dy = dy_vec.data[i]; if (input_grad) { acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3); acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2); acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1); - dx_vec.data[i] = static_cast(dx1 + dx2 + dx3); + dx_vec.data[i] = dx1 + dx2 + dx3; } sum_dy += dy; @@ -311,13 +327,17 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] } } - sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x); - __syncthreads(); - sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x); - __syncthreads(); - sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x); - __syncthreads(); - sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x); + using BlockReduce4 = cub::BlockReduce; + __shared__ typename BlockReduce4::TempStorage reduceStore4; + + float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2); + float4 block_sum_ds = + BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x); + + sum_dy = block_sum_ds.x; + sum_dw0 = block_sum_ds.y; + sum_dw1 = block_sum_ds.z; + sum_dw2 = block_sum_ds.w; if (threadIdx.x == 0) { temp_bias_grad[blockIdx.x] = sum_dy; @@ -364,22 +384,25 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] sum_dx3 += dy * x3; } - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - __syncthreads(); - sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x); - __syncthreads(); - sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x); - __syncthreads(); - sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x); + float3 thread_sums = make_float3(sum2, sum4, sum6); + float3 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + sum2 = block_sums.x; + sum4 = block_sums.y; + sum6 = block_sums.z; + + float3 thread_dxs = make_float3(sum_dx1, sum_dx2, 
sum_dx3); __syncthreads(); - sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x); - __syncthreads(); - sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x); - __syncthreads(); - sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x); + float3 block_sum_dxs = + BlockReduce(reduceStore).Reduce(thread_dxs, SumOp{}, blockDim.x); + + sum_dx1 = block_sum_dxs.x; + sum_dx2 = block_sum_dxs.y; + sum_dx3 = block_sum_dxs.z; __shared__ acc_t s_mean2; __shared__ acc_t s_mean4; @@ -445,13 +468,17 @@ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] sum_dw2 += dy * (x1 * inv_std1); } - sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x); - __syncthreads(); - sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x); - __syncthreads(); - sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x); - __syncthreads(); - sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x); + using BlockReduce4 = cub::BlockReduce; + __shared__ typename BlockReduce4::TempStorage reduceStore4; + + float4 thread_sum_ds = make_float4(sum_dy, sum_dw0, sum_dw1, sum_dw2); + float4 block_sum_ds = + BlockReduce4(reduceStore4).Reduce(thread_sum_ds, SumOp4{}, blockDim.x); + + sum_dy = block_sum_ds.x; + sum_dw0 = block_sum_ds.y; + sum_dw1 = block_sum_ds.z; + sum_dw2 = block_sum_ds.w; if (threadIdx.x == 0) { temp_bias_grad[token_idx] = sum_dy; diff --git a/activation/rms_norm.cu b/activation/rms_norm.cu index 9a2ffd0ac322acad34e56a1b2d53fde6178b714a..ee91d4dd8b2ae2844df52cab7bc57b054a0a97dd 100644 --- a/activation/rms_norm.cu +++ b/activation/rms_norm.cu @@ -7,18 +7,76 @@ #include "assert_utils.h" #include "atomic_utils.h" -#include "block_reduce.h" #include "cuda_compat.h" #include "dispatch_utils.h" namespace motif { -template -__global__ void rms_norm_kernel(scalar_t *__restrict__ out, // [..., d] - const scalar_t *__restrict__ input, // [..., d] - const scalar_t *__restrict__ weight, // [d] - const float eps, const int d) { +template struct 
alignas(sizeof(type) * N) type_vec_t { + type data[N]; +}; +template +__global__ std::enable_if_t<(width > 0)> +rms_norm_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { + using vec_t = type_vec_t; + + const int vec_d = d / width; + const int64_t vec_offset = blockIdx.x * vec_d; + const vec_t *__restrict__ input_vec = reinterpret_cast(input); + acc_t sum_square = 0.0f; + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x = x_vec.data[i]; + sum_square += x * x; + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x); + + __shared__ acc_t s_scale; + + if (threadIdx.x == 0) { + s_scale = rsqrtf(sum_square / d + eps); + } + __syncthreads(); + + const vec_t *__restrict__ weight_vec = + reinterpret_cast(weight); + vec_t *__restrict__ output_vec = reinterpret_cast(out); + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + vec_t w_vec = weight_vec[idx]; + vec_t y_vec; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x = x_vec.data[i]; + acc_t w = w_vec.data[i]; + + y_vec.data[i] = w * x * s_scale; + } + output_vec[vec_offset + idx] = y_vec; + } +} + +template +__global__ std::enable_if_t<(width == 0)> +rms_norm_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { const int64_t token_idx = blockIdx.x; const int64_t vec_idx = threadIdx.x; acc_t sum_square = 0.0f; @@ -28,20 +86,123 @@ __global__ void rms_norm_kernel(scalar_t *__restrict__ out, // [..., d] sum_square += x * x; } - __shared__ acc_t 
shared[BLOCK_SIZE]; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + sum_square = BlockReduce(reduceStore).Sum(sum_square, blockDim.x); + + __shared__ acc_t s_scale; + + if (vec_idx == 0) { + s_scale = rsqrtf(sum_square / d + eps); + } + __syncthreads(); - acc_t variance = - _block_reduce_sum(shared, sum_square, d) / d; - acc_t scale = rsqrt(variance + eps); for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) { acc_t x = input[token_idx * d + idx]; acc_t w = weight[idx]; - out[token_idx * d + idx] = w * x * scale; + out[token_idx * d + idx] = w * x * s_scale; } } +template +__global__ std::enable_if_t<(width > 0)> +rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] + acc_t *__restrict__ temp_weight_grad, // [..., d] + const scalar_t *__restrict__ output_grad, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { + using vec_t = type_vec_t; + using dw_vec_t = type_vec_t; + + const int64_t token_idx = blockIdx.x; + const int64_t vec_idx = threadIdx.x; + + const int vec_d = d / width; + const int64_t vec_offset = token_idx * vec_d; + + const vec_t *__restrict__ input_vec = reinterpret_cast(input); + const vec_t *__restrict__ output_grad_vec = + reinterpret_cast(output_grad); + const vec_t *__restrict__ weight_vec = + reinterpret_cast(weight); + + acc_t d_sum = 0.0f; + acc_t sum_square = 0.0f; + + for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + vidx]; + vec_t dy_vec = output_grad_vec[vec_offset + vidx]; + vec_t w_vec = weight_vec[vidx]; -template -__global__ void +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x = x_vec.data[i]; + acc_t dy = dy_vec.data[i]; + acc_t w = w_vec.data[i]; + d_sum += dy * x * w; + sum_square += x * x; + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; 
+ struct SumOp { + __device__ float2 operator()(const float2 &a, const float2 &b) const { + return make_float2(a.x + b.x, a.y + b.y); + } + }; + float2 thread_sums = make_float2(d_sum, sum_square); + float2 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); + + d_sum = block_sums.x; + sum_square = block_sums.y; + + __shared__ acc_t s_scale; + __shared__ acc_t s_dxx; + + if (threadIdx.x == 0) { + acc_t scale = rsqrtf(sum_square / d + eps); + s_dxx = d_sum * scale * scale * scale / d; + s_scale = scale; + } + __syncthreads(); + acc_t scale = s_scale; + acc_t dxx = s_dxx; + vec_t *__restrict__ input_grad_vec = reinterpret_cast(input_grad); + dw_vec_t *__restrict__ temp_weight_grad_vec = + reinterpret_cast(temp_weight_grad); + + for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + vidx]; + vec_t dy_vec = output_grad_vec[vec_offset + vidx]; + vec_t w_vec = weight_vec[vidx]; + + vec_t in_grad_vec; + dw_vec_t tw_grad_vec; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x = x_vec.data[i]; + acc_t dy = dy_vec.data[i]; + acc_t w = w_vec.data[i]; + + if (input_grad) { + in_grad_vec.data[i] = scale * dy * w - dxx * x; + } + tw_grad_vec.data[i] = dy * x * scale; + } + + if (input_grad) { + input_grad_vec[vec_offset + vidx] = in_grad_vec; + } + temp_weight_grad_vec[vec_offset + vidx] = tw_grad_vec; + } +} + +template +__global__ std::enable_if_t<(width == 0)> rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] acc_t *__restrict__ temp_weight_grad, // [..., d] const scalar_t *__restrict__ output_grad, // [..., d] @@ -61,30 +222,55 @@ rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] sum_square += x * x; } - __shared__ acc_t shared[BLOCK_SIZE]; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + struct SumOp { + __device__ float2 operator()(const float2 &a, const float2 &b) const { + return 
make_float2(a.x + b.x, a.y + b.y); + } + }; + float2 thread_sums = make_float2(d_sum, sum_square); + float2 block_sums = + BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x); - d_sum = _block_reduce_sum(shared, d_sum, d); - acc_t variance = - _block_reduce_sum(shared, sum_square, d) / d; - acc_t scale = rsqrt(variance + eps); - acc_t scale_cubed = scale * scale * scale; - acc_t dxx = d_sum * scale_cubed / d; + d_sum = block_sums.x; + sum_square = block_sums.y; + + __shared__ acc_t s_scale; + __shared__ acc_t s_dxx; + + if (threadIdx.x == 0) { + acc_t scale = rsqrtf(sum_square / d + eps); + s_dxx = d_sum * scale * scale * scale / d; + s_scale = scale; + } + __syncthreads(); + + acc_t scale = s_scale; + acc_t dxx = s_dxx; for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) { acc_t x = input[token_idx * d + idx]; acc_t dy = output_grad[token_idx * d + idx]; acc_t w = weight[idx]; - input_grad[token_idx * d + idx] = scale * dy * w - dxx * x; - - if (temp_weight_grad) { - temp_weight_grad[token_idx * d + idx] = dy * x * scale; + if (input_grad) { + input_grad[token_idx * d + idx] = scale * dy * w - dxx * x; } + temp_weight_grad[token_idx * d + idx] = dy * x * scale; } } } // namespace motif +#define LAUNCH_RMS_NORM(width) \ + MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { \ + motif::rms_norm_kernel \ + <<>>(out.data_ptr(), \ + input.data_ptr(), \ + weight.data_ptr(), eps, d); \ + }); + void rms_norm(torch::Tensor &out, // [..., d] const torch::Tensor &input, // [..., d] const torch::Tensor &weight, // [d] @@ -93,27 +279,36 @@ void rms_norm(torch::Tensor &out, // [..., d] AssertTensorNotNull(weight, "weight"); // TODO shape check - constexpr int BLOCK_SIZE = 256; - int d = input.size(-1); int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - dim3 block(BLOCK_SIZE); + const int max_block_size = (num_tokens < 256) ? 
1024 : 256; + dim3 block(std::min(d, max_block_size)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - motif::rms_norm_kernel - <<>>(out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), eps, d); - }); + if (d % 8 == 0) { + LAUNCH_RMS_NORM(8); + } else { + LAUNCH_RMS_NORM(0); + } } +#define LAUNCH_RMS_NORM_BWD(width) \ + MOTIF_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "rms_norm_backward_kernel", [&] { \ + motif::rms_norm_backward_kernel \ + <<>>(input_grad.data_ptr(), \ + temp_weight_grad.data_ptr(), \ + output_grad.data_ptr(), \ + input.data_ptr(), \ + weight.data_ptr(), eps, d); \ + }); + void rms_norm_backward(torch::Tensor &input_grad, // [..., d] - torch::Tensor &weight_grad, // [..., d] - const torch::Tensor &output_grad, // [d] - const torch::Tensor &input, // [d] + torch::Tensor &weight_grad, // [d] + const torch::Tensor &output_grad, // [..., d] + const torch::Tensor &input, // [..., d] const torch::Tensor &weight, // [d] double eps) { AssertTensorShapeEqual(input, input_grad, "input", "input_grad"); @@ -122,30 +317,27 @@ void rms_norm_backward(torch::Tensor &input_grad, // [..., d] // TODO shape check // weight_grad, input_grad can be nullable - constexpr int BLOCK_SIZE = 256; - int d = input.size(-1); int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - dim3 block(BLOCK_SIZE); + const int max_block_size = (num_tokens < 256) ? 
1024 : 256; + dim3 block(std::min(d, max_block_size)); torch::Tensor temp_weight_grad = torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - MOTIF_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "rms_norm_backward_kernel", [&] { - motif::rms_norm_backward_kernel - <<>>(input_grad.data_ptr(), - temp_weight_grad.data_ptr(), - output_grad.data_ptr(), - input.data_ptr(), - weight.data_ptr(), eps, d); - }); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (d % 8 == 0) { + LAUNCH_RMS_NORM_BWD(8); + } else { + LAUNCH_RMS_NORM_BWD(0); + } if (weight_grad.defined()) { - at::sum_out(weight_grad, temp_weight_grad, {0}); + torch::Tensor acc = + torch::empty_like(weight_grad, temp_weight_grad.options()); + at::sum_out(acc, temp_weight_grad, {0}); + weight_grad.copy_(acc); } } diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e6ad09a2b1b28b342aac1ee2c58f2d12ac6d01ff --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,35 @@ +# Benchmark Runner + +This script benchmarks **forward/backward performance** of several operations (`rms`, `add_rms`, `poly`, `mul_poly`). +Results can be saved as **CSV files** or **plots**. + +> **Note**
+> To run the benchmarks, you must select the appropriate Torch version along with the corresponding CUDA/ROCm build from within the `build` directory.
+>
+> **Example:**
+>
+> ```bash
+> export PYTHONPATH=$PYTHONPATH:/activation/build/torch27-cxx11-cu128-x86_64-linux
+> ```
+
+## Usage
+
+```bash
+python main.py --case <case> [--plot] [--save-path <dir>]
+```
+
+- `--case` (required): one of `rms`, `add_rms`, `poly`, `mul_poly`
+- `--plot`: save plots instead of CSVs
+- `--save-path`: output directory (default: `./configs/`)
+
+## Examples
+
+```bash
+python main.py --case add_rms --save-path ./results/
+python main.py --case poly --plot --save-path ./plots/
+```
+
+## Output
+
+- CSV: `<case>-fwd-perf.csv`, `<case>-bwd-perf.csv`
+- Plots: `plot_<case>-fwd-perf.png`, `plot_<case>-bwd-perf.png`
diff --git a/benchmarks/cases/__init__.py b/benchmarks/cases/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/benchmarks/cases/__init__.py
@@ -0,0 +1 @@
+
diff --git a/benchmarks/cases/add_rms.py b/benchmarks/cases/add_rms.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e055e197c2a9e8540c94b579e88db63824ce424
--- /dev/null
+++ b/benchmarks/cases/add_rms.py
@@ -0,0 +1,55 @@
+import torch
+from common.diff_engine import DiffCase
+
+import activation
+
+
+class FusedAddRMSNorm(torch.nn.Module):
+
+    def __init__(self, d, eps=1e-6, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(d, dtype=dtype))
+        self.eps = eps
+
+    def forward(self, x, residual):
+        return activation.rms_norm((x + residual), self.weight, self.eps)
+
+
+class AddRMS(DiffCase):
+
+    def build_inputs(self, bs, sl, hidden, dtype, eps):
+        return {
+            "x":
+            torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
+            "residual":
+            torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
+            "weight":
+            torch.ones(hidden, dtype=dtype),
+            "dim":
+            hidden,
+            "eps":
+            eps,
"dtype": + dtype, + } + + def make_naive(self, I): + m = FusedAddRMSNorm(I["dim"], I["eps"], dtype=I["dtype"]) + m.weight = torch.nn.Parameter(I["weight"].detach().clone()) + return m + + def make_cuda(self, I): + m = activation.layers.FusedAddRMSNorm(I["dim"], + I["eps"], + dtype=I["dtype"]) + m.weight = torch.nn.Parameter(I["weight"].detach().clone()) + return m + + def forward(self, obj, I): + return obj(I["x"], I["residual"]) + + def grad_inputs(self, I): + return [I["x"], I["residual"]] + + +CASE = AddRMS() diff --git a/benchmarks/cases/mul_poly.py b/benchmarks/cases/mul_poly.py new file mode 100644 index 0000000000000000000000000000000000000000..48597c3cf321915910fdb0f02c607a03fbab4a3e --- /dev/null +++ b/benchmarks/cases/mul_poly.py @@ -0,0 +1,53 @@ +import torch +from common.diff_engine import DiffCase + +import activation + + +class FusedMulPolyNorm(torch.nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward(self, x, mul): + output = activation.poly_norm(x, self.weight, self.bias, self.eps) + return output * mul + + +class MulPoly(DiffCase): + + def build_inputs(self, bs, sl, hidden, dtype, eps): + return { + "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True), + "mul": torch.randn(bs, sl, hidden, dtype=dtype, + requires_grad=True), + "weight": torch.ones(3, dtype=dtype), + "bias": torch.ones(1, dtype=dtype), + "dim": hidden, + "eps": eps, + "dtype": dtype, + } + + def make_naive(self, I): + m = FusedMulPolyNorm(I["eps"], dtype=I["dtype"]) + m.weight = torch.nn.Parameter(I["weight"].detach().clone()) + m.bias = torch.nn.Parameter(I["bias"].detach().clone()) + return m + + def make_cuda(self, I): + m = activation.layers.FusedMulPolyNorm(I["eps"], dtype=I["dtype"]) + m.weight = torch.nn.Parameter(I["weight"].detach().clone()) + m.bias = 
torch.nn.Parameter(I["bias"].detach().clone()) + return m + + def forward(self, obj, I): + return obj(I["x"], I["mul"]) + + def grad_inputs(self, I): + return [I["x"], I["mul"]] + + +CASE = MulPoly() diff --git a/benchmarks/cases/poly.py b/benchmarks/cases/poly.py new file mode 100644 index 0000000000000000000000000000000000000000..00efede1e6a911e3f278343a9d49c36e619c7684 --- /dev/null +++ b/benchmarks/cases/poly.py @@ -0,0 +1,58 @@ +import torch +from common.diff_engine import DiffCase + +import activation + + +class PolyNorm(torch.nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def _norm(self, x): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + orig_dtype = x.dtype + x_float = x.to(torch.float32) + output = (self.weight[0] * self._norm(x_float**3) + + self.weight[1] * self._norm(x_float**2) + + self.weight[2] * self._norm(x_float) + self.bias) + return output.to(orig_dtype) + + +class Poly(DiffCase): + + def build_inputs(self, bs, sl, hidden, dtype, eps): + return { + "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True), + "weight": torch.ones(3, dtype=dtype), + "bias": torch.ones(1, dtype=dtype), + "dim": hidden, + "eps": eps, + "dtype": dtype, + } + + def make_naive(self, I): + m = PolyNorm(I["eps"], dtype=I["dtype"]) + m.weight = torch.nn.Parameter(I["weight"].detach().clone()) + m.bias = torch.nn.Parameter(I["bias"].detach().clone()) + return m + + def make_cuda(self, I): + m = activation.layers.PolyNorm(I["eps"], dtype=I["dtype"]) + m.weight = torch.nn.Parameter(I["weight"].detach().clone()) + m.bias = torch.nn.Parameter(I["bias"].detach().clone()) + return m + + def forward(self, obj, I): + return obj(I["x"]) + + def grad_inputs(self, I): + return [I["x"]] + + +CASE = Poly() diff --git 
a/benchmarks/cases/rms.py b/benchmarks/cases/rms.py new file mode 100644 index 0000000000000000000000000000000000000000..15331e8f42f5b4405b4d3a8317f3fa00b56c62ae --- /dev/null +++ b/benchmarks/cases/rms.py @@ -0,0 +1,35 @@ +import torch +from common.diff_engine import DiffCase + +import activation + + +class RMS(DiffCase): + + def build_inputs(self, bs, sl, hidden, dtype, eps): + return { + "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True), + "weight": torch.ones(hidden, dtype=dtype), + "dim": hidden, + "eps": eps, + "dtype": dtype, + } + + def make_naive(self, I): + m = torch.nn.RMSNorm(I["dim"], I["eps"], dtype=I["dtype"]) + m.weight = torch.nn.Parameter(I["weight"].detach().clone()) + return m + + def make_cuda(self, I): + m = activation.layers.RMSNorm(I["dim"], I["eps"], dtype=I["dtype"]) + m.weight = torch.nn.Parameter(I["weight"].detach().clone()) + return m + + def forward(self, obj, I): + return obj(I["x"]) + + def grad_inputs(self, I): + return [I["x"]] + + +CASE = RMS() diff --git a/benchmarks/common/__init__.py b/benchmarks/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/benchmarks/common/__init__.py @@ -0,0 +1 @@ + diff --git a/benchmarks/common/bench_framework.py b/benchmarks/common/bench_framework.py new file mode 100644 index 0000000000000000000000000000000000000000..f24b8d8d6163a85ce347eeb3d3ad06f46f1c67cf --- /dev/null +++ b/benchmarks/common/bench_framework.py @@ -0,0 +1,220 @@ +import collections +import math +import re +from typing import Any, Dict, Sequence + +import torch +import triton + +from .diff_engine import DiffCase + + +def make_fwd_key(batch_size, seq_len, dim): + return f"forward : ({batch_size}, {seq_len}, {dim})" + + +def make_bwd_key(batch_size, seq_len, dim): + return f"backward : ({batch_size}, {seq_len}, {dim})" + + +def parse_config_string(config_str): + match = 
re.match(r"(\w+)\s*:\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", + config_str) + if not match: + raise ValueError(f"Invalid config string: {config_str}") + _, bs, sl, d = match.groups() + return int(bs), int(sl), int(d) + + +def make_fwd_benchmark_for_case( + *, + case: DiffCase, + configs: Sequence[tuple[int, int, int]], + plot_name: str, + ylabel: str = "us", + line_vals=("naive", "cuda", "speedup"), + line_names: Dict[str, str] | None = None, + dtype=torch.bfloat16, + eps: float = 1e-6, + time_unit_scale: float = 1000, +): + timings_ms = collections.defaultdict(dict) + line_vals = list(line_vals) + line_names = line_names or {v: v.title() for v in line_vals} + x_vals = [list(_) for _ in configs] + + @triton.testing.perf_report( + triton.testing.Benchmark(x_names=["dim", "batch_size", "seq_len"], + x_vals=x_vals, + line_arg="provider", + line_vals=line_vals, + line_names=[line_names[v] for v in line_vals], + ylabel=ylabel, + plot_name=plot_name, + args={})) + def bench(dim, batch_size, seq_len, provider): + key = make_fwd_key(dim, batch_size, seq_len) + I = case.build_inputs(batch_size, seq_len, dim, dtype, eps) + if provider == "speedup": + return timings_ms["naive"][key] / timings_ms["cuda"][key] + obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) + run = lambda: case.forward(obj, I) + ms = triton.testing.do_bench(run) + timings_ms[provider][key] = ms + return time_unit_scale * ms + + return bench + + +def make_fwd_benchmark_plot_for_case( + *, + case: DiffCase, + configs: Sequence[tuple[int, int, int]], + plot_name: str, + ylabel: str = "Relative Speedup", + line_vals=("naive", "cuda"), + line_names: Dict[str, str] | None = None, + dtype=torch.bfloat16, + eps: float = 1e-6, +): + timings_ms = collections.defaultdict(dict) + spdup_ratio = list() + line_vals = list(line_vals) + line_names = line_names or {v: v.title() for v in line_vals} + x_vals = [make_fwd_key(*_) for _ in configs] + x_vals.append("Geometric Mean") + + 
@triton.testing.perf_report( + triton.testing.Benchmark(x_names=["config"], + x_vals=x_vals, + line_arg="provider", + line_vals=line_vals, + line_names=[line_names[v] for v in line_vals], + ylabel=ylabel, + plot_name=plot_name, + args={})) + def bench(config, provider): + if config == "Geometric Mean": + if provider == "cuda": + return round(math.prod(spdup_ratio)**(1 / len(spdup_ratio)), 2) + else: + return 1.00 + batch_size, seq_len, dim = parse_config_string(config) + I = case.build_inputs(batch_size, seq_len, dim, dtype, eps) + obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) + run = lambda: case.forward(obj, I) + ms = triton.testing.do_bench(run) + timings_ms[provider][config] = ms + if provider == "cuda": + ratio = timings_ms["naive"][config] / timings_ms["cuda"][config] + spdup_ratio.append(ratio) + return round(ratio, 2) + else: + return 1.00 + + return bench + + +def make_bwd_benchmark_for_case( + *, + case: DiffCase, + configs: Sequence[tuple[int, int, int]], + plot_name: str, + ylabel: str = "us", + line_vals=("naive", "cuda", "speedup"), + line_names: Dict[str, str] | None = None, + dtype=torch.bfloat16, + eps: float = 1e-6, + time_unit_scale: float = 1000, +): + timings_ms = collections.defaultdict(dict) + line_vals = list(line_vals) + line_names = line_names or {v: v.title() for v in line_vals} + x_vals = [list(_) for _ in configs] + + @triton.testing.perf_report( + triton.testing.Benchmark(x_names=["dim", "batch_size", "seq_len"], + x_vals=x_vals, + line_arg="provider", + line_vals=line_vals, + line_names=[line_names[v] for v in line_vals], + ylabel=ylabel, + plot_name=plot_name, + args={})) + def bench(dim, batch_size, seq_len, provider): + key = make_bwd_key(dim, batch_size, seq_len) + I = case.build_inputs(batch_size, seq_len, dim, dtype, eps) + if provider == "speedup": + return timings_ms["naive"][key] / timings_ms["cuda"][key] + obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) + y = case.forward(obj, 
I) + gin = list(case.grad_inputs(I)) + list(obj.parameters()) + g = torch.randn_like(y) + run = lambda: torch.autograd.grad(y, + gin, + g, + retain_graph=True, + create_graph=False, + allow_unused=False) + ms = triton.testing.do_bench(run) + timings_ms[provider][key] = ms + return time_unit_scale * ms + + return bench + + +def make_bwd_benchmark_plot_for_case( + *, + case: DiffCase, + configs: Sequence[tuple[int, int, int]], + plot_name: str, + ylabel: str = "Relative Speedup", + line_vals=("naive", "cuda"), + line_names: Dict[str, str] | None = None, + dtype=torch.bfloat16, + eps: float = 1e-6, +): + timings_ms = collections.defaultdict(dict) + spdup_ratio = list() + line_vals = list(line_vals) + line_names = line_names or {v: v.title() for v in line_vals} + x_vals = [make_bwd_key(*_) for _ in configs] + x_vals.append("Geometric Mean") + + @triton.testing.perf_report( + triton.testing.Benchmark(x_names=["config"], + x_vals=x_vals, + line_arg="provider", + line_vals=line_vals, + line_names=[line_names[v] for v in line_vals], + ylabel=ylabel, + plot_name=plot_name, + args={})) + def bench(config, provider): + if config == "Geometric Mean": + if provider == "cuda": + return round(math.prod(spdup_ratio)**(1 / len(spdup_ratio)), 2) + else: + return 1.00 + batch_size, seq_len, dim = parse_config_string(config) + I = case.build_inputs(batch_size, seq_len, dim, dtype, eps) + obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) + y = case.forward(obj, I) + gin = list(case.grad_inputs(I)) + list(obj.parameters()) + g = torch.randn_like(y) + run = lambda: torch.autograd.grad(y, + gin, + g, + retain_graph=True, + create_graph=False, + allow_unused=False) + ms = triton.testing.do_bench(run) + timings_ms[provider][config] = ms + if provider == "cuda": + ratio = timings_ms["naive"][config] / timings_ms["cuda"][config] + spdup_ratio.append(ratio) + return round(ratio, 2) + else: + return 1.00 + + return bench diff --git a/benchmarks/common/diff_engine.py 
b/benchmarks/common/diff_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..276cd3900c34a69740cb35ec3183491055b32751 --- /dev/null +++ b/benchmarks/common/diff_engine.py @@ -0,0 +1,85 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Sequence + +import torch + + +class DiffCase(ABC): + + @abstractmethod + def build_inputs(self, hidden: int, bs: int, sl: int, dtype: torch.dtype, + eps: float) -> Dict[str, Any]: + ... + + @abstractmethod + def make_naive(self, I: Dict[str, Any]) -> Any: + ... + + @abstractmethod + def make_cuda(self, I: Dict[str, Any]) -> Any: + ... + + @abstractmethod + def forward(self, obj: Any, I: Dict[str, Any]) -> torch.Tensor: + ... + + @abstractmethod + def grad_inputs(self, I: Dict[str, Any]) -> Sequence[torch.Tensor]: + ... + + +def _clone_payload(d, device): + out = {} + for k, v in d.items(): + if isinstance(v, torch.Tensor): + t = v.detach().clone().to(device) + t.requires_grad_(v.requires_grad) + out[k] = t + else: + out[k] = v + return out + + +def _unit_grad_like(y): + g = torch.randn_like(y) + n = g.norm() + return g if n == 0 else g / n + + +def calculate_diff( + case: DiffCase, + *, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype=torch.bfloat16, + eps: float = 1e-6, + atol: float = 1e-2, + rtol: float = 1e-2, + device="cuda", +) -> None: + base = case.build_inputs(hidden_size, batch_size, seq_len, dtype, eps) + I_n = _clone_payload(base, device) + I_c = _clone_payload(base, device) + obj_n = case.make_naive(I_n) + obj_c = case.make_cuda(I_c) + y_n = case.forward(obj_n, I_n) + y_c = case.forward(obj_c, I_c) + torch.testing.assert_close(y_n, y_c, atol=atol, rtol=rtol) + gin_n = list(case.grad_inputs(I_n)) + list(obj_n.parameters()) + gin_c = list(case.grad_inputs(I_c)) + list(obj_c.parameters()) + g = _unit_grad_like(y_n).to(device) + ng = torch.autograd.grad(y_n, + gin_n, + g, + retain_graph=False, + create_graph=False, + allow_unused=False) + cg = 
torch.autograd.grad(y_c, + gin_c, + g, + retain_graph=False, + create_graph=False, + allow_unused=False) + torch.testing.assert_close(ng, cg, atol=atol, rtol=rtol) + print("✅ forward + backward match") diff --git a/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png b/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..b596caf613cf58456d23286293dca72d747f377f Binary files /dev/null and b/benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png differ diff --git a/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png b/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..132e33a45d291c152b44398be9f5d75d2abaa232 Binary files /dev/null and b/benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png differ diff --git a/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png b/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..d8919a93e3d7be46a9d6cd57d832eff48377a637 Binary files /dev/null and b/benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png differ diff --git a/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png b/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..5191efe759849638f7ac9f0f619c8eb56a04975d Binary files /dev/null and b/benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png differ diff --git a/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png b/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..06d36211ff47c6c1a6feac15d5ea1c708c289ece Binary files /dev/null and b/benchmarks/plots/h100/poly/plot_poly-bwd-perf.png differ diff --git a/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png b/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png new file mode 100644 index 
0000000000000000000000000000000000000000..59f9cbd32aa6532bfcf5284f51a3d56be9d3da9d Binary files /dev/null and b/benchmarks/plots/h100/poly/plot_poly-fwd-perf.png differ diff --git a/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png b/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..95815d90490f6c5916fa4835271473ac2565b079 Binary files /dev/null and b/benchmarks/plots/h100/rms/plot_rms-bwd-perf.png differ diff --git a/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png b/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..5b86645cfecfbb0e9645527530a9fbd092f68b5f Binary files /dev/null and b/benchmarks/plots/h100/rms/plot_rms-fwd-perf.png differ diff --git a/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png b/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..f36820d99706b319eeda8f8a3ed5246383d2524f Binary files /dev/null and b/benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png differ diff --git a/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png b/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..dacca9cf397419f9af992675388f7299b35e30c0 Binary files /dev/null and b/benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png differ diff --git a/benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png b/benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..3145730e039fd475d3fb990f0acaafc0dd1dd70e Binary files /dev/null and b/benchmarks/plots/mi250/mul_poly/plot_mul_poly-bwd-perf.png differ diff --git a/benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png b/benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png new file mode 100644 index 
0000000000000000000000000000000000000000..ba02f90b3811c6aa78f7d34aa856d8047a5cec34 Binary files /dev/null and b/benchmarks/plots/mi250/mul_poly/plot_mul_poly-fwd-perf.png differ diff --git a/benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png b/benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..327e7267c2bcff9842032d418855539e4757fefa Binary files /dev/null and b/benchmarks/plots/mi250/poly/plot_poly-bwd-perf.png differ diff --git a/benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png b/benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..71544c6ab53b849ab0770442cc7d409044529aa7 Binary files /dev/null and b/benchmarks/plots/mi250/poly/plot_poly-fwd-perf.png differ diff --git a/benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png b/benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..95bd8ee01a111de532b6fc2c9e957ac48df3efe2 Binary files /dev/null and b/benchmarks/plots/mi250/rms/plot_rms-bwd-perf.png differ diff --git a/benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png b/benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png new file mode 100644 index 0000000000000000000000000000000000000000..65e5fa7275861bd5f36ee81e576e27f11e1975b3 Binary files /dev/null and b/benchmarks/plots/mi250/rms/plot_rms-fwd-perf.png differ diff --git a/benchmarks/run_cases.py b/benchmarks/run_cases.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e1f746ccf14aed82182c482dc39c05fe03816e --- /dev/null +++ b/benchmarks/run_cases.py @@ -0,0 +1,143 @@ +import argparse +import glob +import importlib +import itertools +import os + +import torch +from common.bench_framework import (make_bwd_benchmark_for_case, + make_bwd_benchmark_plot_for_case, + make_fwd_benchmark_for_case, + make_fwd_benchmark_plot_for_case) +from common.diff_engine import DiffCase, calculate_diff + + +def 
make_title_tag(): + if torch.cuda.is_available(): + dev_name = torch.cuda.get_device_name(0) + else: + dev_name = "CPU" + + torch_ver = torch.__version__ + + return f"[{dev_name} | torch {torch_ver}]" + + +def plot_result(r_path): + import matplotlib.pyplot as plt + import pandas as pd + df = pd.read_csv(r_path + ".csv") + plt.figure(figsize=(12, 6)) + ax = df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=plt.gca()) + ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(), + fontsize=14, + fontweight="bold") + ax.set_ylabel("Relative Speedup", fontsize=14) + ax.set_xlabel("") + plt.xticks(rotation=45, fontsize=12, ha="right", rotation_mode="anchor") + for container in ax.containers: + labels = [f"x{v.get_height():.2f}" for v in container] + ax.bar_label(container, labels=labels, label_type="edge", fontsize=10) + plt.tight_layout() + plt.savefig(r_path + ".png", bbox_inches="tight") + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--case", + choices=["rms", "add_rms", "poly", "mul_poly"], + required=True) + ap.add_argument("--plot", action="store_true") + ap.add_argument( + "--save-path", + type=str, + default="./configs/", + help="Path to save benchmark results", + ) + args = ap.parse_args() + + torch.set_default_device("cuda") + mod = importlib.import_module(f"cases.{args.case}") + case: DiffCase = mod.CASE + + calculate_diff( + case, + batch_size=2, + seq_len=128, + hidden_size=4096, + ) + + save_dir = os.path.join(args.save_path, args.case) + if args.plot: + batch_size_range = [1] + seq_length_range = [4096, 8192, 16384] + dim = [8192, 16384] if "poly" in args.case else [2048, 4096] + configs = list( + itertools.product(batch_size_range, seq_length_range, dim)) + plot_name = f"plot_{args.case}-fwd-perf" + bench = make_fwd_benchmark_plot_for_case( + case=case, + configs=configs, + plot_name=plot_name, + line_names={ + "naive": "Naive", + "cuda": "Cuda", + }, + ) + bench.run(print_data=True, save_path=save_dir) + 
plot_result(os.path.join(save_dir, plot_name)) + + plot_name = f"plot_{args.case}-bwd-perf" + bench = make_bwd_benchmark_plot_for_case( + case=case, + configs=configs, + plot_name=plot_name, + line_names={ + "naive": "Naive", + "cuda": "Cuda", + }, + ) + bench.run(print_data=True, save_path=save_dir) + plot_result(os.path.join(save_dir, plot_name)) + for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob( + os.path.join(save_dir, "*.csv")): + os.remove(f) + else: + batch_size_range = [2**i for i in range(0, 4, 1)] + seq_length_range = [2**i for i in range(10, 14, 1)] + dim = [8192, 16384] if "poly" in args.case else [2048, 4096] + configs = list( + itertools.product(dim, batch_size_range, seq_length_range)) + + bench = make_fwd_benchmark_for_case( + case=case, + configs=configs, + plot_name=f"{args.case}-fwd-perf", + line_names={ + "naive": "Naive", + "cuda": "Cuda", + "speedup": "SpeedUp" + }, + ) + + bench.run(print_data=True, save_path=save_dir) + + bench = make_bwd_benchmark_for_case( + case=case, + configs=configs, + plot_name=f"{args.case}-bwd-perf", + line_names={ + "naive": "Naive", + "cuda": "Cuda", + "speedup": "SpeedUp" + }, + ) + + bench.run(print_data=True, save_path=save_dir) + for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob( + os.path.join(save_dir, "*.png")): + os.remove(f) + + +if __name__ == "__main__": + main() diff --git a/build.toml b/build.toml index ce1953079d04f1265c38cdbc6080bbecc7705d1c..a61a661991131939392cb8ab771f57b320b68867 100644 --- a/build.toml +++ b/build.toml @@ -13,9 +13,10 @@ backend = "rocm" rocm-archs = [ "gfx90a", "gfx942" ] src = [ "activation/poly_norm.cu", + "activation/fused_mul_poly_norm.cu", "activation/rms_norm.cu", + "activation/fused_add_rms_norm.cu", "activation/cuda_compat.h", - "activation/block_reduce.h", "activation/dispatch_utils.h", "activation/assert_utils.h", "activation/atomic_utils.h", @@ -26,9 +27,10 @@ depends = [ "torch" ] backend = "cuda" src = [ "activation/poly_norm.cu", + 
"activation/fused_mul_poly_norm.cu", "activation/rms_norm.cu", + "activation/fused_add_rms_norm.cu", "activation/cuda_compat.h", - "activation/block_reduce.h", "activation/dispatch_utils.h", "activation/assert_utils.h", "activation/atomic_utils.h", diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..1b3674d54c044dddf2d037c1d3bac522bc19440c --- /dev/null +++ 
b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d21a85bf21aa74f1281541e658acfd4f4326d902efe3578b059eccf054443284 +size 8089696 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = 
torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. """ init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. 
+ """ + init.ones_(self.weight) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py index 
4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). + def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py index 
fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..df3c3ae7785a3c30c36d900923c1dd7a349448db --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74d4955271509451b946495da75f69a0f978e7258b8303fe3c077e585c0d3e6a +size 8272456 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py index 
11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. 
+ """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. """ init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . 
import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..0de3488964fc7207148b7b9b62cc4db838e64c7b --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bf0d2ab5ff5520704e0b0c959b61d0043d360cfd4335950e69677873a87e436 +size 12792112 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. 
""" init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . 
import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..57361102c13046a6a1aab2f7125193ece35b21da --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:640322a8fac8fd9d8e9f195a3034c4ee0f81ee1acf897fd7c482a84ce47a1bec +size 4160688 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. 
""" init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . 
import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..c703b3b19594e8b20ee5b4dc7692fbdad8079365 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1768d8d5072ac06d937cb5332988c6b3bfaa191f72d1369a22d2c577e9a3bca2 +size 8215280 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. 
""" init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . 
import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ecdc467a674247fe3898453418ce88a9983d08c5 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37a572bd877980ab8c0331ca5682191cb5a2b1f05bc69ea493a9e24f7728ba3f +size 12730840 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. 
""" init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . 
import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..d6c8a74ea050b78cf9dcd4c43ac618094b0ca303 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f15919c4cac697cde550af16256e338472400e50df751e93622350c7f626bc8 +size 12726208 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. 
""" init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . 
import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..670a8291fdc208c690447600ee77449e1fac9929 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e72d4bb4459a5da96ca5eda1d305237a361140f0e25360e3d20326a22f1b6d47 +size 4165584 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. 
""" init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . 
import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..c8f702b9ecfdc1c01dcdd2880d088458c4f11c2d --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_20250907180255.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3325c2748cf7a070383068995078f93f440cc95fbed491d00bd414cdd851376 +size 4171472 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py index 11632044e1d56e11f7646a5a027b0aea5439e2af..a5ff861cc76e68ae4de5758b7acafa38f915e62a 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . 
import _activation_f517c97_dirty -ops = torch.ops._activation_f517c97_dirty +from . import _activation_20250907180255 +ops = torch.ops._activation_20250907180255 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file + return f"_activation_20250907180255::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. 
""" init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/tests/kernels/allclose_default.py b/tests/allclose_default.py similarity index 100% rename from tests/kernels/allclose_default.py rename to tests/allclose_default.py diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index eda6569be97917dc009f15b3ca4f11136708aa0c..0000000000000000000000000000000000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,144 +0,0 @@ -import logging - -import numpy as np -import plotly.graph_objects as go -import pytest - -from .kernels.test_poly_norm_perf import PERF_RESULTS, PerfResult - -logger = logging.getLogger(__name__) -DO_PLOT = False - - -def plot(perf_results: list[PerfResult]): - x_labels = [f"{r.type}, {r.shape}, {r.dtype}" for r in perf_results] - kernel_speedup = [r.speedup for r in perf_results] - torch_speedup = [1 for _ in perf_results] - - geo_mean = float(np.exp(np.mean(np.log(kernel_speedup)))) - x_labels.append("Geometric Mean") - kernel_speedup.append(geo_mean) - torch_speedup.append(1.0) - - fig = go.Figure() - - bar_width = 0.2 - fig.add_trace( - go.Bar( - x=x_labels, - 
y=kernel_speedup, - name="Activation", - marker_color="rgb(100, 100, 100)", - text=[f"x{v:.2f}" for v in kernel_speedup], - textfont=dict(size=14), - textposition="outside", - # width=[bar_width] * len(x_labels), - )) - - fig.add_trace( - go.Bar( - x=x_labels, - y=torch_speedup, - name="Torch", - marker_color="rgb(30, 30, 30)", - text=[f"x{v:.2f}" for v in torch_speedup], - textfont=dict(size=14), - textposition="outside", - # width=[bar_width] * len(x_labels), - )) - - fig.update_layout( - title=dict( - text= - "Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)", - font=dict(size=24), - ), - legend=dict( - x=0.01, - y=0.99, - xanchor="left", - yanchor="top", - bgcolor="rgba(0,0,0,0)", - bordercolor="black", - borderwidth=1, - ), - font=dict(size=16), - yaxis_title="Speedup (torch / activation)", - barmode="group", - bargroupgap=0, - bargap=0.2, - xaxis_tickangle=-45, - template="plotly_white", - yaxis_type="log", - shapes=[ - dict( - type="rect", - xref="x", - yref="paper", # y축 전체 범위 (0~1) - x0=-0.5, - x1=len(x_labels) - 0.5, - y0=0, - y1=1, - line=dict( - color="black", - width=1.5, - ), - fillcolor="rgba(0,0,0,0)", # 투명 배경 - layer="above", # bar 아래에 그리기 - ) - ], - ) - - output_file = "perf_result.html" - fig.write_html(output_file) - logger.info(f"Plotting performance results to {output_file}") - - -def pytest_addoption(parser): - parser.addoption("--run-perf", - action="store_true", - default=False, - help="Run perf tests") - parser.addoption("--do-plot", - action="store_true", - default=False, - help="Plot performance results") - - -@pytest.fixture -def do_plot(request): - return request.config.getoption("--do-plot") - - -def pytest_configure(config): - global DO_PLOT - DO_PLOT = config.getoption("--do-plot") - run_perf = config.getoption("--run-perf") - - if DO_PLOT and not run_perf: - raise ValueError( - "Cannot plot performance results without running performance tests. 
" - "Please use --run-perf option.") - - config.addinivalue_line("markers", - "perf: mark test as performance-related") - - -def pytest_collection_modifyitems(config, items): - run_perf = config.getoption("--run-perf") - - skip_perf = pytest.mark.skip(reason="need --run-perf option to run") - skip_normal = pytest.mark.skip( - reason="normal tests skipped when --run-perf is used") - for item in items: - if "perf" in item.keywords and not run_perf: - item.add_marker(skip_perf) - elif "perf" not in item.keywords and run_perf: - item.add_marker(skip_normal) - - -def pytest_sessionfinish(session, exitstatus) -> None: - if DO_PLOT: - plot(PERF_RESULTS) - else: - logger.info(PERF_RESULTS) diff --git a/tests/kernels/__init__.py b/tests/kernels/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/tests/kernels/test_poly_norm_perf.py b/tests/kernels/test_poly_norm_perf.py deleted file mode 100644 index 88c35f34e8acb6cb835eedb7c1d6657e16b28da3..0000000000000000000000000000000000000000 --- a/tests/kernels/test_poly_norm_perf.py +++ /dev/null @@ -1,121 +0,0 @@ -import random -from dataclasses import dataclass - -import pytest -import torch - -import activation - -from .test_poly_norm import poly_norm -from .utils import assert_close - -CASES = [ - ((1, 2048, 8192), torch.bfloat16), - ((1, 2048, 16384), torch.bfloat16), - ((1, 16384, 8192), torch.bfloat16), - ((1, 16384, 16384), torch.bfloat16), -] -NUM_REP = 100 - - -@dataclass -class PerfResult: - type: str # forward or backward - shape: tuple - dtype: torch.dtype - kernel_time_ms: float - torch_time_ms: float - - @property - def speedup(self) -> float: - return self.torch_time_ms / self.kernel_time_ms - - -PERF_RESULTS: list[PerfResult] = [] - - -@pytest.mark.parametrize("cases", CASES) -@pytest.mark.perf -def test_poly_norm( - cases: tuple, - do_plot: bool, -) -> None: - random.seed(12345) - torch.manual_seed(12345) - - 
torch.set_default_device("cuda") - - shape, dtype = cases - x = torch.randn(shape, dtype=dtype, requires_grad=True) - weight = torch.randn(3, dtype=dtype, requires_grad=True) - bias = torch.randn(1, dtype=dtype, requires_grad=True) - eps = 1e-05 - - x.retain_grad() - weight.retain_grad() - bias.retain_grad() - # To separate gradient computation, clone the inputs - - x_ref = x.detach().clone().requires_grad_(True) - weight_ref = weight.detach().clone().requires_grad_(True) - bias_ref = bias.detach().clone().requires_grad_(True) - - torch_fn = poly_norm - layer = activation.layers.PolyNorm(eps) - layer.weight = torch.nn.Parameter(weight) - layer.bias = torch.nn.Parameter(bias) - - # Check correctness - mod_out = layer(x) - ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps) - assert_close(mod_out, ref_out) - - out_grad = torch.rand_like(ref_out) - out_grad = out_grad / out_grad.norm() - - ref_out.backward(out_grad, retain_graph=True) - mod_out.backward(out_grad, retain_graph=True) - - assert_close(x.grad, x_ref.grad) - assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05) - assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05) - - def time_cuda(fn): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - for _ in range(5): - fn() - start.record() - for _ in range(NUM_REP): - fn() - end.record() - torch.cuda.synchronize() - return start.elapsed_time(end) / NUM_REP - - kernel_time_ms = time_cuda(lambda: layer(x)) - torch_fn_time = time_cuda( - lambda: torch_fn(x_ref, weight_ref, bias_ref, eps)) - - PERF_RESULTS.append( - PerfResult( - type="forward", - shape=shape, - dtype=dtype, - kernel_time_ms=kernel_time_ms, - torch_time_ms=torch_fn_time, - )) - - kernel_time_ms = time_cuda( - lambda: mod_out.backward(out_grad, retain_graph=True)) - torch_fn_time = time_cuda( - lambda: ref_out.backward(out_grad, retain_graph=True)) - - PERF_RESULTS.append( - PerfResult( - type="backward", - shape=shape, - dtype=dtype, - 
kernel_time_ms=kernel_time_ms, - torch_time_ms=torch_fn_time, - )) diff --git a/tests/perf.png b/tests/perf.png deleted file mode 100644 index 398b315aff1f0af712c8b0996bacfe78992c49ed..0000000000000000000000000000000000000000 --- a/tests/perf.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:12f88f9ac4511cb37f38a34e3572e4347bd0c857144a4aaf64bd5981d6b50877 -size 165982 diff --git a/tests/perf_result.html b/tests/perf_result.html deleted file mode 100644 index ccd99a78e8c7212caa63e79e7035477b94576b8c..0000000000000000000000000000000000000000 --- a/tests/perf_result.html +++ /dev/null @@ -1,3885 +0,0 @@ - - - -
-
- - \ No newline at end of file diff --git a/tests/test_fused_add_rms_norm.py b/tests/test_fused_add_rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..5486471c131e4543917cf7f399553e6dc71a7a72 --- /dev/null +++ b/tests/test_fused_add_rms_norm.py @@ -0,0 +1,101 @@ +import random + +import pytest +import torch + +import activation + +from .utils import assert_close, opcheck + +DTYPES = [torch.float, torch.bfloat16, torch.half] +NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing +D = [1, 7, 512, 13824] # Arbitrary values for testing +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def add_rms_norm_all_naive(x: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, eps: float) -> torch.Tensor: + return torch.nn.functional.rms_norm((x + residual), weight.shape, weight, + eps) + + +#use rms_norm kernel +def add_rms_norm_partial_naive(x: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, + eps: float) -> torch.Tensor: + return activation.rms_norm((x + residual), weight, eps) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_fused_add_rms_norm( + num_tokens: int, + d: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.manual_seed(seed) + torch.set_default_device(device) + + x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True) + residual = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True) + weight = torch.randn(d, dtype=dtype, requires_grad=True) + eps = 1e-05 + + x.retain_grad() + residual.retain_grad() + weight.retain_grad() + # To separate gradient computation, clone the inputs + + x_ref = x.detach().clone().requires_grad_(True) + residual_ref = residual.detach().clone().requires_grad_(True) + weight_ref = 
weight.detach().clone().requires_grad_(True) + + x_ref2 = x.detach().clone().requires_grad_(True) + residual_ref2 = residual.detach().clone().requires_grad_(True) + weight_ref2 = weight.detach().clone().requires_grad_(True) + + torch_fn = add_rms_norm_all_naive + torch_fn2 = add_rms_norm_partial_naive + + op = activation.ops.fused_add_rms_norm + fn = activation.fused_add_rms_norm + + layer = activation.layers.FusedAddRMSNorm(d, eps) + layer.weight = torch.nn.Parameter(weight) + + out = torch.empty(x.shape, dtype=x.dtype, device=x.device) + add_out = torch.empty(x.shape, dtype=x.dtype, device=x.device) + opcheck(op, (out, add_out, x, residual, weight, eps)) + + out = fn(x, residual, weight, eps) + mod_out = layer(x, residual) + ref_out = torch_fn(x_ref, residual_ref, weight_ref, eps) + ref_out2 = torch_fn2(x_ref2, residual_ref2, weight_ref2, eps) + + assert_close(out, ref_out) + assert_close(out, ref_out2) + assert_close(mod_out, out, atol=0.0, rtol=0.0) + + # test backward pass + out_grad = torch.randn_like(out) + out_grad = out_grad / out_grad.norm() + + ref_out.backward(out_grad) + ref_out2.backward(out_grad) + mod_out.backward(out_grad) + + assert_close(x.grad, x_ref.grad) + assert_close(x.grad, x_ref2.grad) + assert_close(residual.grad, residual_ref.grad) + assert_close(residual.grad, residual_ref2.grad) + assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05) + assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05) diff --git a/tests/test_fused_mul_poly_norm.py b/tests/test_fused_mul_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..0f77c3d32bdbe0b87d51ffacc5280fec497cd81b --- /dev/null +++ b/tests/test_fused_mul_poly_norm.py @@ -0,0 +1,120 @@ +import random + +import pytest +import torch + +import activation + +from .utils import assert_close, opcheck + +DTYPES = [torch.float, torch.bfloat16, torch.half] +NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing +D = [1, 7, 512, 13824] # Arbitrary values for 
testing +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def norm(x, eps: float) -> torch.Tensor: + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, + eps: float) -> torch.Tensor: + x = x.float() + return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) + + weight[2] * norm(x, eps) + bias).to(weight.dtype) + + +def mul_poly_norm_all_naive(x: torch.Tensor, mul: torch.Tensor, + weight: torch.Tensor, bias: torch.Tensor, + eps: float) -> torch.Tensor: + return poly_norm(x, weight, bias, eps) * mul + + +#use poly_norm kernel +def mul_poly_norm_partial_naive(x: torch.Tensor, mul: torch.Tensor, + weight: torch.Tensor, bias: torch.Tensor, + eps: float) -> torch.Tensor: + return activation.poly_norm(x, weight, bias, eps) * mul + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_fused_mul_poly_norm( + num_tokens: int, + d: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.manual_seed(seed) + torch.set_default_device(device) + + x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True) + mul = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True) + weight = torch.randn(3, dtype=dtype, requires_grad=True) + bias = torch.randn(1, dtype=dtype, requires_grad=True) + eps = 1e-05 + + x.retain_grad() + mul.retain_grad() + weight.retain_grad() + bias.retain_grad() + # To separate gradient computation, clone the inputs + + x_ref = x.detach().clone().requires_grad_(True) + mul_ref = mul.detach().clone().requires_grad_(True) + weight_ref = weight.detach().clone().requires_grad_(True) + bias_ref = bias.detach().clone().requires_grad_(True) + + x_ref2 = x.detach().clone().requires_grad_(True) + 
mul_ref2 = mul.detach().clone().requires_grad_(True) + weight_ref2 = weight.detach().clone().requires_grad_(True) + bias_ref2 = bias.detach().clone().requires_grad_(True) + + torch_fn = mul_poly_norm_all_naive + torch_fn2 = mul_poly_norm_partial_naive + + op = activation.ops.fused_mul_poly_norm + fn = activation.fused_mul_poly_norm + + layer = activation.layers.FusedMulPolyNorm(eps) + layer.weight = torch.nn.Parameter(weight) + layer.bias = torch.nn.Parameter(bias) + + out = torch.empty(x.shape, dtype=x.dtype, device=x.device) + opcheck(op, (out, x, mul, weight, bias, eps)) + + out = fn(x, mul, weight, bias, eps) + mod_out = layer(x, mul) + ref_out = torch_fn(x_ref, mul_ref, weight_ref, bias_ref, eps) + ref_out2 = torch_fn2(x_ref2, mul_ref2, weight_ref2, bias_ref2, eps) + + # Mul amplifies small numeric differences between naive poly_norm and the kernel. + # When validating against all_naive, use a looser rtol/atol. + assert_close(out, ref_out, atol=0.01, rtol=0.01) + assert_close(out, ref_out2) + assert_close(mod_out, out, atol=0.0, rtol=0.0) + + # test backward pass + out_grad = torch.randn_like(out) + out_grad = out_grad / out_grad.norm() + + ref_out.backward(out_grad) + ref_out2.backward(out_grad) + mod_out.backward(out_grad) + + assert_close(x.grad, x_ref.grad) + assert_close(x.grad, x_ref2.grad) + assert_close(mul.grad, mul_ref.grad) + assert_close(mul.grad, mul_ref2.grad) + assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05) + assert_close(layer.bias.grad, bias_ref2.grad, rtol=0.05) + assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05) + assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05) diff --git a/tests/kernels/test_poly_norm.py b/tests/test_poly_norm.py similarity index 92% rename from tests/kernels/test_poly_norm.py rename to tests/test_poly_norm.py index dab3d26a722e7a1db6917e3745df165ed2f08ffb..3fb3129482251d4994a3ce2540cfa925ba49b700 100644 --- a/tests/kernels/test_poly_norm.py +++ b/tests/test_poly_norm.py @@ -8,10 +8,8 @@ 
import activation from .utils import assert_close, opcheck DTYPES = [torch.float, torch.bfloat16, torch.half] -# NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing -# D = [512, 13824] # Arbitrary values for testing -NUM_TOKENS = [7, 13] # Arbitrary values for testing -D = [513] # Arbitrary values for testing +NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing +D = [1, 7, 512, 13824] # Arbitrary values for testing SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) diff --git a/tests/kernels/test_rms_norm.py b/tests/test_rms_norm.py similarity index 90% rename from tests/kernels/test_rms_norm.py rename to tests/test_rms_norm.py index 20fef91a1b9a82c380f04c273e4b169636f98cca..f812ec9cf22fe0fcb23c040f8ffbf8624e344ef2 100644 --- a/tests/kernels/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -8,10 +8,8 @@ import activation from .utils import assert_close, opcheck DTYPES = [torch.float, torch.bfloat16, torch.half] -# NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing -# D = [512, 13824] # Arbitrary values for testing -NUM_TOKENS = [7, 13] # Arbitrary values for testing -D = [513] # Arbitrary values for testing +NUM_TOKENS = [7, 83, 256, 2048] # Arbitrary values for testing +D = [1, 7, 512, 13824] # Arbitrary values for testing SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) diff --git a/tests/kernels/utils.py b/tests/utils.py similarity index 100% rename from tests/kernels/utils.py rename to tests/utils.py diff --git a/torch-ext/activation/__init__.py b/torch-ext/activation/__init__.py index fadce68f2fa0f130463f00a59c3436b822835e24..254ab917bab4ccd4a19327c1fe0bf96060f11217 100644 --- a/torch-ext/activation/__init__.py +++ b/torch-ext/activation/__init__.py @@ -2,8 +2,8 @@ import torch from . 
import layers from ._ops import ops -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction def poly_norm( @@ -15,6 +15,16 @@ def poly_norm( return PolyNormFunction.apply(x, weight, bias, eps) +def fused_mul_poly_norm( + x: torch.Tensor, + mul: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedMulPolyNormFunction.apply(x, mul, weight, bias, eps) + + def rms_norm( x: torch.Tensor, weight: torch.Tensor, @@ -23,8 +33,20 @@ def rms_norm( return RMSNormFunction.apply(x, weight, eps) +def fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0] + + __all__ = [ "poly_norm", + "fused_mul_poly_norm", + "rms_norm", + "fused_add_rms_norm", "layers", "ops", ] diff --git a/torch-ext/activation/layers.py b/torch-ext/activation/layers.py index 8ec01852a54649c04ba10f50aecb3f1a41576d18..156ea42df607e920731ad932d3a5b5d3a472c157 100644 --- a/torch-ext/activation/layers.py +++ b/torch-ext/activation/layers.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn import init -from .poly_norm import PolyNormFunction -from .rms_norm import RMSNormFunction +from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction +from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction class PolyNorm(nn.Module): @@ -28,6 +28,30 @@ class PolyNorm(nn.Module): init.zeros_(self.bias) +class FusedMulPolyNorm(nn.Module): + + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + mul: torch.Tensor, + ): + return 
FusedMulPolyNormFunction.apply(x, mul, self.weight, self.bias, + self.eps) + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) + init.zeros_(self.bias) + + class RMSNorm(nn.Module): def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): @@ -46,3 +70,25 @@ class RMSNorm(nn.Module): Resets parameters based on their initialization used in __init__. """ init.ones_(self.weight) + + +class FusedAddRMSNorm(nn.Module): + + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) + self.eps = eps + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ): + return FusedAddRMSNormFunction.apply(x, residual, self.weight, + self.eps)[0] + + def reset_parameters(self) -> None: + """ + Resets parameters based on their initialization used in __init__. + """ + init.ones_(self.weight) diff --git a/torch-ext/activation/poly_norm.py b/torch-ext/activation/poly_norm.py index e9f13435bd79b865dca42a4d84a9fd7e9f3ea479..8a0fd85f1835e02a36eb9184874d77dcad8221f9 100644 --- a/torch-ext/activation/poly_norm.py +++ b/torch-ext/activation/poly_norm.py @@ -37,3 +37,40 @@ class PolyNormFunction(torch.autograd.Function): input, weight, eps) return input_grad, weight_grad, bias_grad, None + + +class FusedMulPolyNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, mul, weight, bias, eps): + output = torch.empty_like(input) + ops.fused_mul_poly_norm(output, input, mul, weight, bias, eps) + return output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, output): + input, mul, weight, bias, eps = inputs + ctx.save_for_backward(input, mul, weight, bias) + ctx.eps = eps + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, output_grad): + input, mul, weight, bias = ctx.saved_tensors + eps = ctx.eps + + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + mul_grad = torch.empty_like(mul) if ctx.needs_input_grad[1] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[3] else None) + + ops.fused_mul_poly_norm_backward(input_grad, mul_grad, weight_grad, + bias_grad, output_grad, input, mul, + weight, bias, eps) + + return input_grad, mul_grad, weight_grad, bias_grad, None diff --git a/torch-ext/activation/rms_norm.py b/torch-ext/activation/rms_norm.py index 4ce81d593f3edd8f4d14ce41f822e8547188c049..7b4274f3a59c423a8662edf3bb8728a1daacb71f 100644 --- a/torch-ext/activation/rms_norm.py +++ b/torch-ext/activation/rms_norm.py @@ -35,3 +35,50 @@ class RMSNormFunction(torch.autograd.Function): weight, eps) return input_grad, weight_grad, None + + +# Inherit from Function +class FusedAddRMSNormFunction(torch.autograd.Function): + # Note that forward, setup_context, and backward are @staticmethods + @staticmethod + def forward(input, residual, weight, eps): + output = torch.empty_like(input) + add_output = torch.empty_like(input) + ops.fused_add_rms_norm(output, add_output, input, residual, weight, + eps) + return output, add_output + + @staticmethod + # inputs is a Tuple of all of the inputs passed to forward. + # output is the output of the forward(). 
+ def setup_context(ctx, inputs, outputs): + _, _, weight, eps = inputs + _, add_output = outputs + ctx.mark_non_differentiable(add_output) + ctx.set_materialize_grads(False) + ctx.save_for_backward(weight, add_output) + ctx.eps = eps + + # This function only needs one gradient + @staticmethod + def backward(ctx, output_grad, _): + weight, add_output = ctx.saved_tensors + eps = ctx.eps + + if output_grad is None: + output_grad = torch.zeros_like(add_output) + + need_in = ctx.needs_input_grad[0] + need_res = ctx.needs_input_grad[1] + + grad = torch.empty_like(output_grad) if need_in or need_res else None + + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[2] else None + + ops.rms_norm_backward(grad, weight_grad, output_grad, add_output, + weight, eps) + input_grad = grad if need_in else None + residual_grad = grad if need_res else None + + return input_grad, residual_grad, weight_grad, None diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp index 74fb10647ac4d9c3704d2cbf41f235d2c6495e56..5c71a5ba1e025e276b32d6878defb0ab1f28ffd1 100644 --- a/torch-ext/torch_binding.cpp +++ b/torch-ext/torch_binding.cpp @@ -1,25 +1,50 @@ +#include "torch_binding.h" + #include #include "registration.h" -#include "torch_binding.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { - // Activation ops + // poly_norm ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, " "float eps) -> ()"); + ops.impl("poly_norm", torch::kCUDA, &poly_norm); + ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! " "bias_grad, Tensor output_grad, Tensor input, Tensor weight, float " "eps) -> ()"); - ops.impl("poly_norm", torch::kCUDA, &poly_norm); ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward); - // Activation ops + // rms_norm ops.def( "rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()"); + ops.impl("rms_norm", torch::kCUDA, &rms_norm); + ops.def("rms_norm_backward(Tensor! 
input_grad, Tensor! weight_grad, Tensor " "output_grad, Tensor input, Tensor weight, float eps) -> ()"); - ops.impl("rms_norm", torch::kCUDA, &rms_norm); ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward); + + // fused_mul_poly_norm + ops.def("fused_mul_poly_norm(Tensor! out, Tensor input, Tensor mul, Tensor " + "weight, Tensor bias, " + "float eps) -> ()"); + ops.impl("fused_mul_poly_norm", torch::kCUDA, &fused_mul_poly_norm); + + ops.def("fused_mul_poly_norm_backward(Tensor! input_grad, Tensor! mul_grad, " + "Tensor! weight_grad, Tensor! " + "bias_grad, Tensor output_grad, Tensor input, Tensor mul, Tensor " + "weight, Tensor " + "bias, float eps) -> ()"); + ops.impl("fused_mul_poly_norm_backward", torch::kCUDA, + &fused_mul_poly_norm_backward); + + // fused_add_rms_norm + // fused_add_rms_norm_backward uses rms_norm_backward_kernel + ops.def( + "fused_add_rms_norm(Tensor! out, Tensor! add_out, Tensor input, Tensor " + "residual, Tensor " + "weight, float eps) -> ()"); + ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h index a64ae48061cf3b34405793564ebd1c9d76ffb006..b3629b1afbe07d99f6a60bd8c1f846fdc46ff8b0 100644 --- a/torch-ext/torch_binding.h +++ b/torch-ext/torch_binding.h @@ -17,3 +17,19 @@ void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad, const torch::Tensor &output_grad, const torch::Tensor &input, const torch::Tensor &weight, double eps); + +void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input, + const torch::Tensor &mul, const torch::Tensor &weights, + const torch::Tensor &bias, double eps); +void fused_mul_poly_norm_backward( + torch::Tensor &input_grad, torch::Tensor &mul_grad, + torch::Tensor &weight_grad, torch::Tensor &bias_grad, + const torch::Tensor &output_grad, const torch::Tensor &input, + const torch::Tensor &mul, const torch::Tensor &weight, + const 
torch::Tensor &bias, double eps); + +// fused_add_rms_norm_backward uses rms_norm_backward_kernel +void fused_add_rms_norm(torch::Tensor &out, torch::Tensor &add_out, + const torch::Tensor &input, + const torch::Tensor &residual, + const torch::Tensor &weight, double eps);