| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include <ATen/cuda/CUDAContext.h> |
| #include <c10/cuda/CUDAGuard.h> |
| #include <torch/all.h> |
|
|
| #ifndef USE_ROCM |
|
|
| #include <flashinfer/activation.cuh> |
|
|
| #include "utils.h" |
|
|
| #else |
| #include "hip/hip_act_and_mul.cuh" |
| #endif |
|
|
| |
| |
|
|
namespace detail {

// Converts a value of element type T to float.
// Non-ROCm builds use a plain static_cast; ROCm builds route through
// castToFloat (declared outside this file).
template <typename T>
__device__ __forceinline__ float to_f32(const T& value) {
#if !USE_ROCM
  return static_cast<float>(value);
#else
  return castToFloat(value);
#endif
}

// Converts a float back to element type T.
// Non-ROCm builds use a plain static_cast; ROCm builds route through
// castFromFloat<T> (declared outside this file).
template <typename T>
__device__ __forceinline__ T from_f32(float value) {
#if !USE_ROCM
  return static_cast<T>(value);
#else
  return castFromFloat<T>(value);
#endif
}

}  // namespace detail
|
|
// SiLU activation: x / (1 + exp(-x)).
// The value is widened to float for the arithmetic and converted back to T.
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  const float v = detail::to_f32(x);
  const float denom = 1.0f + expf(-v);
  return detail::from_f32<T>(v / denom);
}
|
|
// GELU activation (erf form): x * 0.5 * (1 + erf(x / sqrt(2))).
// The value is widened to float for the arithmetic and converted back to T.
template <typename T>
__device__ __forceinline__ T gelu(const T& x) {
  // 1 / sqrt(2), narrowed from the double-precision macro to float.
  constexpr float kInvSqrt2 = M_SQRT1_2;
  const float v = detail::to_f32(x);
  const float cdf = 0.5f * (1.0f + erf(v * kInvSqrt2));
  return detail::from_f32<T>(v * cdf);
}
|
|
| |
// "Quick" GELU approximation: x / (1 + exp(-1.702 * x)).
// The value is widened to float for the arithmetic and converted back to T.
template <typename T>
__device__ __forceinline__ T gelu_quick_act(const T& x) {
  const float v = detail::to_f32(x);
  const float denom = 1.0f + expf(-v * 1.702f);
  return detail::from_f32<T>(v / denom);
}
|
|
// GELU activation (tanh approximation):
//   x * 0.5 * (1 + tanh(kScale * (x + kCubic * x^3)))
// The value is widened to float for the arithmetic and converted back to T.
template <typename T>
__device__ __forceinline__ T gelu_tanh(const T& x) {
  constexpr float kCubic = 0.044715f;
  constexpr float kScale = 0.7978845608028654f;  // sqrt(2 / pi)
  const float v = detail::to_f32(x);
  // Multiplication order (kCubic * v * v * v, left to right) is kept as-is so
  // the rounding of the cubic term is unchanged.
  const float inner = kScale * (v + kCubic * v * v * v);
  const float cdf = 0.5f * (1.0f + tanhf(inner));
  return detail::from_f32<T>(v * cdf);
}
|
|
// Launches act_and_mul_kernel with the SiLU activation.
//
// `d` is half of input's last dimension; the grid has one block per row of
// `input` (numel / last-dim size). The kernel receives out's and input's raw
// data pointers and `d`.
// NOTE(review): presumably `out` must already be allocated with last dimension
// `d` and the same dtype/device as `input` -- nothing here checks that.
void silu_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to do for an empty tensor. Returning early also avoids an integer
  // division by a zero-sized last dimension and a launch with a zero-sized
  // grid, which is an invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Make input's device current BEFORE asking for the current stream, so the
  // stream belongs to the device the kernel is actually launched on.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    // Number of elements in 16 bytes for this dtype.
    uint32_t vec_size = 16 / sizeof(c_type);
    // NOTE(review): this is 0 when d < vec_size; confirm callers never pass
    // such a small hidden size.
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
|
|
// Launches act_and_mul_kernel with the tanh-approximated GELU activation.
//
// `d` is half of input's last dimension; the grid has one block per row of
// `input` (numel / last-dim size). The kernel receives out's and input's raw
// data pointers and `d`.
// NOTE(review): presumably `out` must already be allocated with last dimension
// `d` and the same dtype/device as `input` -- nothing here checks that.
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to do for an empty tensor. Returning early also avoids an integer
  // division by a zero-sized last dimension and a launch with a zero-sized
  // grid, which is an invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Make input's device current BEFORE asking for the current stream, so the
  // stream belongs to the device the kernel is actually launched on.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    // Number of elements in 16 bytes for this dtype.
    uint32_t vec_size = 16 / sizeof(c_type);
    // NOTE(review): this is 0 when d < vec_size; confirm callers never pass
    // such a small hidden size.
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}
|
|
// Launches act_and_mul_kernel with the erf-based GELU activation.
//
// `d` is half of input's last dimension; the grid has one block per row of
// `input` (numel / last-dim size). The kernel receives out's and input's raw
// data pointers and `d`.
// NOTE(review): presumably `out` must already be allocated with last dimension
// `d` and the same dtype/device as `input` -- nothing here checks that.
void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
  // Nothing to do for an empty tensor. Returning early also avoids an integer
  // division by a zero-sized last dimension and a launch with a zero-sized
  // grid, which is an invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Make input's device current BEFORE asking for the current stream, so the
  // stream belongs to the device the kernel is actually launched on.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    // Number of elements in 16 bytes for this dtype.
    uint32_t vec_size = 16 / sizeof(c_type);
    // NOTE(review): this is 0 when d < vec_size; confirm callers never pass
    // such a small hidden size.
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif

    return true;
  });
}
|
|
| #if USE_ROCM |
// Launches act_only_kernel with the quick-GELU activation (ROCm builds only).
//
// Unlike the *_and_mul launchers, `d` is the full last dimension of `input`.
// The grid has one block per row of `input` (numel / last-dim size). The kernel
// receives out's and input's raw data pointers and `d`.
// NOTE(review): presumably `out` must already be allocated with the same shape,
// dtype and device as `input` -- nothing here checks that.
void gelu_quick(at::Tensor& out, const at::Tensor& input) {
  // Nothing to do for an empty tensor. Returning early also avoids an integer
  // division by a zero-sized last dimension and a launch with a zero-sized
  // grid, which is an invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Make input's device current BEFORE asking for the current stream, so the
  // stream belongs to the device the kernel is actually launched on.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    // Number of elements in 16 bytes for this dtype.
    uint32_t vec_size = 16 / sizeof(c_type);
    // NOTE(review): this is 0 when d < vec_size; confirm callers never pass
    // such a small hidden size.
    dim3 block(std::min(d / vec_size, 1024U));
    sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}
| #endif |
|
|