#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>  // for at::cuda::getCurrentCUDAStream()
#include <mma.h>

using namespace nvcuda;
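
// Computes output = ReLU(input @ (weight * mask)^T + bias) entirely in FP16,
// using WMMA Tensor Core tiles. Requires a GPU with compute capability >= 7.0
// (Volta or newer) and contiguous row-major tensors.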

using half_t = __half;

// Tile sizes: the canonical m16n16k16 WMMA shape for FP16.
#define WARP_SIZE 32
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

// One warp per block; each block computes one WMMA_M x WMMA_N output tile.
// The sparsity mask is applied to the weights as they are staged into shared
// memory, fusing masked GEMM, bias add, and ReLU into a single kernel.
__global__ void fused_sparse_gemm_relu_kernel(
    const half_t* __restrict__ input,   // [batch_size, in_features]
    const half_t* __restrict__ weight,  // [out_features, in_features]
    const half_t* __restrict__ mask,    // [out_features, in_features]
    half_t* __restrict__ output,        // [batch_size, out_features]
    const half_t* __restrict__ bias,    // [out_features]
    int batch_size, int in_features, int out_features)
{
    // Staging buffers for one activation tile, one masked weight tile, and
    // the accumulated result.
    __shared__ half_t shmem_input[WMMA_M * WMMA_K];
    __shared__ half_t shmem_weight[WMMA_N * WMMA_K];
    __shared__ half_t shmem_out[WMMA_M * WMMA_N];

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half_t> c_frag;

    const int tile_row = blockIdx.y * WMMA_M;  // first batch row of this tile
    const int tile_col = blockIdx.x * WMMA_N;  // first output feature of this tile
    const int lane = threadIdx.x;              // lane id within the single warp

    wmma::fill_fragment(c_frag, __float2half(0.0f));

    for (int k = 0; k < in_features; k += WMMA_K) {
        // Cooperatively stage the 16x16 activation tile: 32 lanes cover the
        // 256 elements in 8 strided steps; out-of-range elements are zeroed
        // so ragged tile edges contribute nothing to the accumulator.
        for (int i = lane; i < WMMA_M * WMMA_K; i += WARP_SIZE) {
            int r = tile_row + i / WMMA_K;
            int c = k + i % WMMA_K;
            shmem_input[i] = (r < batch_size && c < in_features)
                ? input[r * in_features + c] : __float2half(0.0f);
        }
        // Stage the weight tile with the sparsity mask already applied.
        // Stored as [n][k] over output features, which is exactly a
        // col-major K x N matrix with leading dimension WMMA_K.
        for (int i = lane; i < WMMA_N * WMMA_K; i += WARP_SIZE) {
            int n = tile_col + i / WMMA_K;
            int c = k + i % WMMA_K;
            shmem_weight[i] = (n < out_features && c < in_features)
                ? __hmul(weight[n * in_features + c], mask[n * in_features + c])
                : __float2half(0.0f);
        }
        __syncwarp();

        wmma::load_matrix_sync(a_frag, shmem_input, WMMA_K);
        wmma::load_matrix_sync(b_frag, shmem_weight, WMMA_K);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        __syncwarp();
    }

    // The accumulator fragment's element-to-thread mapping is opaque, so
    // spill it to shared memory before the elementwise epilogue.
    wmma::store_matrix_sync(shmem_out, c_frag, WMMA_N, wmma::mem_row_major);
    __syncwarp();

    // Fused epilogue: add bias and apply ReLU while writing the tile out.
    for (int i = lane; i < WMMA_M * WMMA_N; i += WARP_SIZE) {
        int r = tile_row + i / WMMA_N;
        int c = tile_col + i % WMMA_N;
        if (r < batch_size && c < out_features) {
            half_t v = __hadd(shmem_out[i], bias[c]);
            output[r * out_features + c] =
                __hgt(v, __float2half(0.0f)) ? v : __float2half(0.0f);
        }
    }
}

torch::Tensor fused_sparse_gemm_relu(
    torch::Tensor input,
    torch::Tensor weight,
    torch::Tensor mask,
    torch::Tensor bias)
{
    TORCH_CHECK(input.dtype() == torch::kFloat16, "Input must be FP16");
    TORCH_CHECK(weight.dtype() == torch::kFloat16, "Weight must be FP16");
    TORCH_CHECK(mask.dtype() == torch::kFloat16, "Mask must be FP16");
    TORCH_CHECK(bias.dtype() == torch::kFloat16, "Bias must be FP16");
    TORCH_CHECK(input.is_cuda(), "Input must be on CUDA");
    TORCH_CHECK(weight.is_cuda(), "Weight must be on CUDA");
    TORCH_CHECK(mask.is_cuda(), "Mask must be on CUDA");
    TORCH_CHECK(bias.is_cuda(), "Bias must be on CUDA");
    TORCH_CHECK(input.is_contiguous() && weight.is_contiguous() &&
                mask.is_contiguous() && bias.is_contiguous(),
                "All tensors must be contiguous");

    int batch_size = input.size(0);
    int in_features = input.size(1);
    int out_features = weight.size(0);

    auto output = torch::empty({batch_size, out_features},
        torch::TensorOptions().dtype(torch::kFloat16).device(input.device()));

    // One warp (32 threads) per block, one block per WMMA_M x WMMA_N tile of
    // the [batch_size, out_features] output.
    dim3 block(WARP_SIZE);
    dim3 grid((out_features + WMMA_N - 1) / WMMA_N,
              (batch_size + WMMA_M - 1) / WMMA_M);

    fused_sparse_gemm_relu_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        reinterpret_cast<const half_t*>(input.data_ptr<at::Half>()),
        reinterpret_cast<const half_t*>(weight.data_ptr<at::Half>()),
        reinterpret_cast<const half_t*>(mask.data_ptr<at::Half>()),
        reinterpret_cast<half_t*>(output.data_ptr<at::Half>()),
        reinterpret_cast<const half_t*>(bias.data_ptr<at::Half>()),
        batch_size, in_features, out_features);

    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA error: ", cudaGetErrorString(err));

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_sparse_gemm_relu", &fused_sparse_gemm_relu,
          "Fused sparse GEMM + ReLU with Tensor Cores");
}
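
// Example usage from Python (a sketch; the file name and tensor names are
// illustrative, not part of this file):
//
//   from torch.utils.cpp_extension import load
//   ext = load(name="fused_sparse_gemm_relu", sources=["fused_sparse_gemm_relu.cu"])
//   y = ext.fused_sparse_gemm_relu(x.half().cuda(), w.half().cuda(),
//                                  m.half().cuda(), b.half().cuda())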