#include <cstdint>
#include <cuda_runtime.h>
#include <torch/torch.h>
#include <type_traits>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
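
// atomicAdd16: emulate an atomic add on a 2-byte floating-point element
// (at::Half / at::BFloat16) with a compare-and-swap loop on the aligned
// 32-bit word that contains it. The accumulation itself is done in float,
// and the neighboring 16-bit element in the same word is left untouched.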
template <typename T>
__device__ inline void atomicAdd16(T *addr, T val) {
  static_assert(sizeof(T) == 2, "atomicAdd16 expects a 2-byte type");

  // Locate the 4-byte-aligned word holding this element and which half it is.
  std::uintptr_t uaddr = reinterpret_cast<std::uintptr_t>(addr);
  unsigned int *base =
      reinterpret_cast<unsigned int *>(uaddr & ~std::uintptr_t(0x3));
  const bool hi_half = (uaddr & 0x2) != 0;

  unsigned int old32 = *base, assumed;
  do {
    assumed = old32;

    // Extract the current 16-bit payload from the assumed 32-bit word.
    unsigned short cur16 = hi_half ? (assumed >> 16) : (assumed & 0xFFFFu);

    // Reinterpret the raw bits as T and accumulate in float for accuracy.
    T cur;
    *reinterpret_cast<unsigned short *>(&cur) = cur16;
    float f = static_cast<float>(cur) + static_cast<float>(val);

    // Convert back to T and recover its raw 16-bit pattern.
    T res = static_cast<T>(f);
    unsigned short res16 = *reinterpret_cast<unsigned short *>(&res);

    // Splice the updated half back into the 32-bit word, preserving the
    // other half.
    unsigned int new32 =
        hi_half ? ((assumed & 0x0000FFFFu) |
                   (static_cast<unsigned int>(res16) << 16))
                : ((assumed & 0xFFFF0000u) | static_cast<unsigned int>(res16));

    // Publish only if no other thread modified the word in the meantime.
    old32 = atomicCAS(base, assumed, new32);
  } while (old32 != assumed);
}
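// atomicAddT: route an atomic add to the right primitive for the element
// type: native atomicAdd for float and double (double requires compute
// capability 6.0+), the CAS-based 16-bit emulation above otherwise.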
template <typename T>
__device__ inline void atomicAddT(T *addr, T val) {
  if constexpr (std::is_same<T, float>::value ||
                std::is_same<T, double>::value) {
    // Hardware atomic add is available for these types.
    atomicAdd(addr, val);
  } else {
    // 16-bit types (at::Half, at::BFloat16) fall back to the CAS loop.
    atomicAdd16(addr, val);
  }
}
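// scatter_kernel: one block per (expert e, capacity slot i). Each block reads
// the expert-local row src[e][i][:] and atomically accumulates it into the
// destination token row y[tok][:], optionally scaled by a per-assignment
// weight. `bins` holds cumulative per-expert counts into the expert-sorted
// index array `idx`; the destination token id is recovered by dividing the
// flat index by top_k.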
template <typename scalar_t>
__global__ void scatter_kernel(
    const scalar_t *__restrict__ src,
    const int *__restrict__ idx,
    const int *__restrict__ bins,
    const scalar_t *__restrict__ weights,
    scalar_t *__restrict__ y,
    int T,       // number of destination token rows
    int H,       // hidden (feature) dimension
    int E,       // number of experts
    int C,       // per-expert capacity
    int top_k) {
  int e = blockIdx.x;  // expert index
  int i = blockIdx.y;  // slot within this expert's capacity
  if (e >= E || i >= C)
    return;

  // Rows [start, end) of idx belong to expert e.
  const int end = bins[e];
  const int start = (e == 0) ? 0 : bins[e - 1];
  const int n = end - start;

  // Skip padded capacity slots and out-of-range tokens.
  bool valid = (i < n);
  int tok = 0;
  if (valid) {
    int flat = idx[start + i];
    tok = flat / top_k;  // recover the destination token id
    if (tok < 0 || tok >= T)
      valid = false;
  }
  if (!valid)
    return;

  const scalar_t *src_row = src + ((size_t)e * C + i) * H;
  scalar_t *y_row = y + (size_t)tok * H;

  // Optional per-assignment scaling (e.g. router weights).
  scalar_t scale = (weights != nullptr) ? weights[start + i] : scalar_t(1.0f);

  // Threads stride over the hidden dimension; the adds must be atomic because
  // several (expert, slot) pairs can target the same token row.
  int t = threadIdx.x;
  for (int h = t; h < H; h += blockDim.x) {
    atomicAddT(&y_row[h], src_row[h] * scale);
  }
}
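// Host-side launcher: one block per (expert, capacity slot), 256 threads
// striding over the hidden dimension. Dispatches over float32/float64/
// float16/bfloat16 and launches on the current PyTorch CUDA stream.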
void scatter_cuda(
    const torch::Tensor &src,
    const torch::Tensor &indices,
    const torch::Tensor &bins,
    const torch::Tensor &weights,
    torch::Tensor &y,
    int64_t T,
    int64_t E,
    int64_t C,
    int64_t top_k) {
  const int64_t H = src.size(2);

  // One block per (expert, capacity slot). Note that gridDim.y is limited to
  // 65535, so this layout assumes the capacity C fits within that bound.
  dim3 grid(static_cast<unsigned int>(E), static_cast<unsigned int>(C));
  int threads = 256;

  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kHalf,
      at::kBFloat16,
      src.scalar_type(),
      "scatter_cuda",
      ([&] {
        scatter_kernel<scalar_t><<<grid, threads, 0, stream>>>(
            src.data_ptr<scalar_t>(),
            indices.data_ptr<int>(),
            bins.data_ptr<int>(),
            weights.defined() ? weights.data_ptr<scalar_t>() : nullptr,
            y.data_ptr<scalar_t>(),
            static_cast<int>(T),
            static_cast<int>(H),
            static_cast<int>(E),
            static_cast<int>(C),
            static_cast<int>(top_k));
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}