Upload torch-ext/torch_binding.cpp with huggingface_hub
Browse files- torch-ext/torch_binding.cpp +104 -0
torch-ext/torch_binding.cpp
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* PyTorch C++ Bindings for Qwen3-8B CUDA Kernels
|
| 3 |
+
* Provides Python-callable wrappers for custom CUDA kernels.
|
| 4 |
+
*/
|
| 5 |
+
|
| 6 |
+
#include <torch/extension.h>
|
| 7 |
+
#include <torch/library.h>
|
| 8 |
+
#include <cuda_runtime.h>
|
| 9 |
+
#include <cuda_fp16.h>
|
| 10 |
+
#include <cuda_bf16.h>
|
| 11 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 12 |
+
|
| 13 |
+
#include "torch_binding.h"
|
| 14 |
+
|
| 15 |
+
#if __has_include("registration.h")
|
| 16 |
+
#include "registration.h"
|
| 17 |
+
#define QWEN3_KERNEL_BUILDER 1
|
| 18 |
+
#else
|
| 19 |
+
#define QWEN3_KERNEL_BUILDER 0
|
| 20 |
+
#endif
|
| 21 |
+
|
| 22 |
+
// External declarations for CUDA kernel launch functions
|
| 23 |
+
extern "C" {
|
| 24 |
+
void rmsnorm_forward_fp16(__half*, const __half*, const __half*, int, int, int, float, cudaStream_t);
|
| 25 |
+
void rmsnorm_forward_bf16(__nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, int, int, int, float, cudaStream_t);
|
| 26 |
+
void rmsnorm_forward_fp32(float*, const float*, const float*, int, int, int, float, cudaStream_t);
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// ============================================================================
|
| 30 |
+
// RMSNorm Binding
|
| 31 |
+
// ============================================================================
|
| 32 |
+
|
| 33 |
+
/**
 * Python-facing RMSNorm forward wrapper.
 *
 * Validates the tensors, flattens every leading dimension of `input` into a
 * single row count, and forwards raw device pointers to the dtype-specific
 * launcher (rmsnorm_forward_fp16 / _bf16 / _fp32) on the current CUDA stream
 * of `input`'s device.
 *
 * @param output  Destination tensor, written in place. Must be a contiguous
 *                CUDA tensor on the same device as `input`, with the same
 *                shape and dtype.
 * @param input   Contiguous CUDA tensor with at least one dimension; the last
 *                dimension is the one handed to the launcher as hidden_size.
 * @param weight  Contiguous 1-D CUDA tensor on the same device and of the same
 *                dtype as `input`, with `hidden_size` elements.
 * @param eps     Epsilon forwarded unchanged to the launcher.
 *
 * @throws c10::Error (via TORCH_CHECK) when any precondition above is
 *         violated, when the dtype is not half / bfloat16 / float, when the
 *         row count or hidden size does not fit in a 32-bit int, or when CUDA
 *         reports an error after the launch.
 *
 * An input with zero elements is a no-op: nothing is launched.
 */
void rmsnorm(
    torch::Tensor& output,
    const torch::Tensor& input,
    const torch::Tensor& weight,
    float eps
) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
    TORCH_CHECK(output.is_cuda(), "output must be a CUDA tensor");
    // The device guard and stream below are taken from `input` only, so the
    // other tensors must live on that same device for their pointers to be valid.
    TORCH_CHECK(weight.device() == input.device(), "weight must be on the same device as input");
    TORCH_CHECK(output.device() == input.device(), "output must be on the same device as input");
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
    TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
    TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "input and weight must have the same dtype");
    TORCH_CHECK(output.scalar_type() == input.scalar_type(), "output must match the input dtype");
    TORCH_CHECK(input.dim() >= 1, "input must have at least one dimension");
    TORCH_CHECK(weight.dim() == 1, "weight must be a 1D tensor");

    const int64_t hidden_size64 = input.size(input.dim() - 1);
    TORCH_CHECK(weight.numel() == hidden_size64, "weight size must match the hidden dimension");
    TORCH_CHECK(output.sizes() == input.sizes(), "output must match the input shape");

    // Reject unsupported dtypes up front so the error is raised regardless of
    // the tensor's size.
    const auto dtype = input.scalar_type();
    TORCH_CHECK(
        dtype == at::kHalf || dtype == at::kBFloat16 || dtype == at::kFloat,
        "Unsupported dtype: ", dtype);

    // Nothing to normalize. Returning here also keeps the division below from
    // ever seeing hidden_size64 == 0.
    if (input.numel() == 0) {
        return;
    }

    const int64_t num_tokens64 = input.numel() / hidden_size64;

    // The launchers take 32-bit ints; refuse sizes that would be truncated
    // instead of silently passing a wrapped value.
    const int hidden_size = static_cast<int>(hidden_size64);
    const int seq_len = static_cast<int>(num_tokens64);
    TORCH_CHECK(
        hidden_size == hidden_size64 && seq_len == num_tokens64,
        "rmsnorm: hidden size and token count must each fit in a 32-bit int (got hidden_size=",
        hidden_size64, ", num_tokens=", num_tokens64, ")");

    // All leading dimensions are folded into seq_len, so batch_size is fixed at 1.
    // NOTE(review): presumably the launchers only use batch_size * seq_len as a
    // row count - confirm against the kernel source.
    const int batch_size = 1;

    const at::cuda::CUDAGuard device_guard(input.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (dtype == at::kHalf) {
        rmsnorm_forward_fp16(
            reinterpret_cast<__half*>(output.data_ptr()),
            reinterpret_cast<const __half*>(input.data_ptr()),
            reinterpret_cast<const __half*>(weight.data_ptr()),
            batch_size, seq_len, hidden_size, eps, stream
        );
    } else if (dtype == at::kBFloat16) {
        rmsnorm_forward_bf16(
            reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),
            reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
            reinterpret_cast<const __nv_bfloat16*>(weight.data_ptr()),
            batch_size, seq_len, hidden_size, eps, stream
        );
    } else {
        // dtype == at::kFloat, guaranteed by the dtype check above.
        rmsnorm_forward_fp32(
            reinterpret_cast<float*>(output.data_ptr()),
            reinterpret_cast<const float*>(input.data_ptr()),
            reinterpret_cast<const float*>(weight.data_ptr()),
            batch_size, seq_len, hidden_size, eps, stream
        );
    }

    // Report a failed launch here rather than at some later, unrelated CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(
        launch_status == cudaSuccess,
        "rmsnorm kernel launch failed: ", cudaGetErrorString(launch_status));
}
|
| 87 |
+
|
| 88 |
+
// ============================================================================
|
| 89 |
+
// Module Registration
|
| 90 |
+
// ============================================================================
|
| 91 |
+
|
| 92 |
+
#if QWEN3_KERNEL_BUILDER
|
| 93 |
+
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 94 |
+
ops.def("rmsnorm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
|
| 95 |
+
ops.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
| 99 |
+
#else
|
| 100 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 101 |
+
m.def("rmsnorm", &rmsnorm, "RMSNorm forward (CUDA)");
|
| 102 |
+
}
|
| 103 |
+
#endif
|
| 104 |
+
|