// CharmAGX_G1/core/kernels/fused_ops.cu
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h> // at::cuda::getCurrentCUDAStream
#include <mma.h> // Tensor Core WMMA API
using namespace nvcuda;
// Define FP16 type for Tensor Cores
using half_t = __half;
// Launch geometry: one warp (32 threads) per thread block
#define WARP_SIZE 32
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
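// Tiling scheme: each thread block holds a single warp, and each warp computes one
// WMMA_M x WMMA_N tile of the output. blockIdx.x walks out_features in WMMA_N steps,
// blockIdx.y walks the batch dimension in WMMA_M steps, and the K dimension
// (in_features) is consumed WMMA_K columns at a time inside the kernel loop.
// Requires a Tensor Core capable GPU (compute capability 7.0 or newer).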
// Fused sparse GEMM + ReLU kernel
__global__ void fused_sparse_gemm_relu_kernel(
const half_t* __restrict__ input, // Input tensor [batch_size, in_features]
const half_t* __restrict__ weight, // Weight tensor [out_features, in_features]
const half_t* __restrict__ mask, // Sparsity mask [out_features, in_features]
half_t* __restrict__ output, // Output tensor [batch_size, out_features]
const half_t* __restrict__ bias, // Bias tensor [out_features]
int batch_size, int in_features, int out_features)
{
    // Shared-memory staging buffers, one 16x16 tile each (32-byte aligned for WMMA)
    __shared__ __align__(32) half_t shmem_a[WMMA_M * WMMA_K]; // input tile, row-major
    __shared__ __align__(32) half_t shmem_b[WMMA_N * WMMA_K]; // masked weight tile, stored as [K x N] column-major
    __shared__ __align__(32) half_t shmem_c[WMMA_M * WMMA_N]; // accumulator staging for the bias/ReLU epilogue
    // WMMA fragments
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half_t> c_frag;
    // Lane index within the warp and the output tile this warp owns
    const int lane = threadIdx.x;               // 0..31
    const int tile_row = blockIdx.y * WMMA_M;   // first batch row of this tile
    const int tile_col = blockIdx.x * WMMA_N;   // first output feature of this tile
    // Initialize accumulator
    wmma::fill_fragment(c_frag, __float2half(0.0f));
    // Loop over the K dimension (in_features) in WMMA_K-wide tiles
    for (int k0 = 0; k0 < in_features; k0 += WMMA_K) {
        // Stage the input tile; each lane loads 256/32 = 8 elements, zero-padding out of range
        for (int idx = lane; idx < WMMA_M * WMMA_K; idx += WARP_SIZE) {
            int r = idx / WMMA_K, c = idx % WMMA_K;
            int g_row = tile_row + r, g_col = k0 + c;
            shmem_a[idx] = (g_row < batch_size && g_col < in_features)
                               ? input[g_row * in_features + g_col]
                               : __float2half(0.0f);
        }
        // Stage the weight tile with the sparsity mask applied; element (k, n) sits at
        // n * WMMA_K + k, i.e. column-major [K x N] as expected by the matrix_b fragment
        for (int idx = lane; idx < WMMA_N * WMMA_K; idx += WARP_SIZE) {
            int n = idx / WMMA_K, c = idx % WMMA_K;
            int g_row = tile_col + n, g_col = k0 + c; // indices into weight/mask [out_features, in_features]
            half_t w = __float2half(0.0f);
            if (g_row < out_features && g_col < in_features) {
                w = __hmul(weight[g_row * in_features + g_col],
                           mask[g_row * in_features + g_col]); // apply sparsity mask
            }
            shmem_b[idx] = w;
        }
        __syncthreads();
        // Load WMMA fragments from shared memory and run the Tensor Core MMA
        wmma::load_matrix_sync(a_frag, shmem_a, WMMA_K);
        wmma::load_matrix_sync(b_frag, shmem_b, WMMA_K);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        __syncthreads();
    }
    // Spill the accumulator tile to shared memory, then apply bias + ReLU on the way out
    wmma::store_matrix_sync(shmem_c, c_frag, WMMA_N, wmma::mem_row_major);
    __syncthreads();
    for (int idx = lane; idx < WMMA_M * WMMA_N; idx += WARP_SIZE) {
        int r = idx / WMMA_N, c = idx % WMMA_N;
        int g_row = tile_row + r, g_col = tile_col + c;
        if (g_row < batch_size && g_col < out_features) {
            half_t v = __hadd(shmem_c[idx], bias[g_col]);              // add bias
            output[g_row * out_features + g_col] =
                __hgt(v, __float2half(0.0f)) ? v : __float2half(0.0f); // ReLU
        }
    }
}
// PyTorch binding
torch::Tensor fused_sparse_gemm_relu(
torch::Tensor input, // [batch_size, in_features]
torch::Tensor weight, // [out_features, in_features]
torch::Tensor mask, // [out_features, in_features]
torch::Tensor bias) // [out_features]
{
// Ensure inputs are FP16 and on CUDA
TORCH_CHECK(input.dtype() == torch::kFloat16, "Input must be FP16");
TORCH_CHECK(weight.dtype() == torch::kFloat16, "Weight must be FP16");
TORCH_CHECK(mask.dtype() == torch::kFloat16, "Mask must be FP16");
TORCH_CHECK(bias.dtype() == torch::kFloat16, "Bias must be FP16");
TORCH_CHECK(input.is_cuda(), "Input must be on CUDA");
TORCH_CHECK(weight.is_cuda(), "Weight must be on CUDA");
TORCH_CHECK(mask.is_cuda(), "Mask must be on CUDA");
TORCH_CHECK(bias.is_cuda(), "Bias must be on CUDA");
TORCH_CHECK(input.is_contiguous() && weight.is_contiguous() && mask.is_contiguous() && bias.is_contiguous(),
            "All tensors must be contiguous");
// Dimensions
int batch_size = input.size(0);
int in_features = input.size(1);
int out_features = weight.size(0);
// Output tensor
auto output = torch::empty({batch_size, out_features},
torch::TensorOptions().dtype(torch::kFloat16).device(input.device()));
// Grid and block dimensions: one warp per block, one WMMA_M x WMMA_N output tile per warp
dim3 block(WARP_SIZE);
dim3 grid((out_features + WMMA_N - 1) / WMMA_N,
          (batch_size + WMMA_M - 1) / WMMA_M);
// Launch kernel on the current PyTorch CUDA stream
fused_sparse_gemm_relu_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
    reinterpret_cast<const half_t*>(input.data_ptr<at::Half>()),
    reinterpret_cast<const half_t*>(weight.data_ptr<at::Half>()),
    reinterpret_cast<const half_t*>(mask.data_ptr<at::Half>()),
    reinterpret_cast<half_t*>(output.data_ptr<at::Half>()),
    reinterpret_cast<const half_t*>(bias.data_ptr<at::Half>()),
    batch_size, in_features, out_features);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
TORCH_CHECK(false, "CUDA error: ", cudaGetErrorString(err));
}
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_sparse_gemm_relu", &fused_sparse_gemm_relu, "Fused sparse GEMM + ReLU with Tensor Cores");
}
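// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the compiled extension): the op
// is intended to be built as a PyTorch C++/CUDA extension and called from
// Python. The module name "fused_ops" below is an assumption for the example.
//
//   from torch.utils.cpp_extension import load
//   fused_ops = load(name="fused_ops", sources=["fused_ops.cu"])
//   x = torch.randn(8, 256, dtype=torch.float16, device="cuda")
//   w = torch.randn(512, 256, dtype=torch.float16, device="cuda")
//   m = (torch.rand(512, 256, device="cuda") > 0.5).half()   # sparsity mask
//   b = torch.zeros(512, dtype=torch.float16, device="cuda")
//   y = fused_ops.fused_sparse_gemm_relu(x, w, m, b)          # [8, 512]
//   # reference for checking: torch.relu(x @ (w * m).t() + b)
// ---------------------------------------------------------------------------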