burtenshaw HF Staff committed on
Commit
e0a93ee
·
verified ·
1 Parent(s): 5426fd3

Upload torch-ext/torch_binding.cpp with huggingface_hub

Browse files
Files changed (1) hide show
  1. torch-ext/torch_binding.cpp +104 -0
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * PyTorch C++ Bindings for Qwen3-8B CUDA Kernels
3
+ * Provides Python-callable wrappers for custom CUDA kernels.
4
+ */
5
+
6
#include <torch/extension.h>
#include <torch/library.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <c10/cuda/CUDAGuard.h>

#include <limits>

#include "torch_binding.h"
14
+
15
// Build-flavour detection: when a "registration.h" header is reachable on the
// include path it is pulled in and QWEN3_KERNEL_BUILDER is set to 1; otherwise
// the macro is 0. The macro selects the module-registration code at the bottom
// of this file.
#if __has_include("registration.h")
#  include "registration.h"
#  define QWEN3_KERNEL_BUILDER 1
#else
#  define QWEN3_KERNEL_BUILDER 0
#endif
21
+
22
+ // External declarations for CUDA kernel launch functions
23
+ extern "C" {
24
+ void rmsnorm_forward_fp16(__half*, const __half*, const __half*, int, int, int, float, cudaStream_t);
25
+ void rmsnorm_forward_bf16(__nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, int, int, int, float, cudaStream_t);
26
+ void rmsnorm_forward_fp32(float*, const float*, const float*, int, int, int, float, cudaStream_t);
27
+ }
28
+
29
+ // ============================================================================
30
+ // RMSNorm Binding
31
+ // ============================================================================
32
+
33
+ void rmsnorm(
34
+ torch::Tensor& output,
35
+ const torch::Tensor& input,
36
+ const torch::Tensor& weight,
37
+ float eps
38
+ ) {
39
+ TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
40
+ TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor");
41
+ TORCH_CHECK(output.is_cuda(), "output must be a CUDA tensor");
42
+ TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
43
+ TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
44
+ TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
45
+ TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "input and weight must have the same dtype");
46
+ TORCH_CHECK(output.scalar_type() == input.scalar_type(), "output must match the input dtype");
47
+ TORCH_CHECK(input.dim() >= 1, "input must have at least one dimension");
48
+ TORCH_CHECK(weight.dim() == 1, "weight must be a 1D tensor");
49
+
50
+ const at::cuda::CUDAGuard device_guard(input.device());
51
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
52
+
53
+ const int ndim = input.dim();
54
+ const int hidden_size = input.size(ndim - 1);
55
+ const int64_t num_tokens = input.numel() / hidden_size;
56
+ TORCH_CHECK(weight.numel() == hidden_size, "weight size must match the hidden dimension");
57
+ TORCH_CHECK(output.sizes() == input.sizes(), "output must match the input shape");
58
+
59
+ const int batch_size = 1;
60
+ const int seq_len = num_tokens;
61
+
62
+ if (input.scalar_type() == at::kHalf) {
63
+ rmsnorm_forward_fp16(
64
+ reinterpret_cast<__half*>(output.data_ptr()),
65
+ reinterpret_cast<const __half*>(input.data_ptr()),
66
+ reinterpret_cast<const __half*>(weight.data_ptr()),
67
+ batch_size, seq_len, hidden_size, eps, stream
68
+ );
69
+ } else if (input.scalar_type() == at::kBFloat16) {
70
+ rmsnorm_forward_bf16(
71
+ reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),
72
+ reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
73
+ reinterpret_cast<const __nv_bfloat16*>(weight.data_ptr()),
74
+ batch_size, seq_len, hidden_size, eps, stream
75
+ );
76
+ } else if (input.scalar_type() == at::kFloat) {
77
+ rmsnorm_forward_fp32(
78
+ reinterpret_cast<float*>(output.data_ptr()),
79
+ reinterpret_cast<const float*>(input.data_ptr()),
80
+ reinterpret_cast<const float*>(weight.data_ptr()),
81
+ batch_size, seq_len, hidden_size, eps, stream
82
+ );
83
+ } else {
84
+ TORCH_CHECK(false, "Unsupported dtype: ", input.scalar_type());
85
+ }
86
+ }
87
+
88
+ // ============================================================================
89
+ // Module Registration
90
+ // ============================================================================
91
+
92
+ #if QWEN3_KERNEL_BUILDER
93
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
94
+ ops.def("rmsnorm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
95
+ ops.impl("rmsnorm", torch::kCUDA, &rmsnorm);
96
+ }
97
+
98
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
99
+ #else
100
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
101
+ m.def("rmsnorm", &rmsnorm, "RMSNorm forward (CUDA)");
102
+ }
103
+ #endif
104
+