| |
| deleted file mode 100644 |
| |
| |
| |
| @@ -1,383 +0,0 @@ |
| -#include "cuda_kernel.h" |
| - |
| -////////////////////////////////////////////////////////////////////////////////////////////////// |
| -////////////////////////////////////////////////////////////////////////////////////////////////// |
| - |
| -__global__ void index_max_cuda_kernel( |
| - float *index_vals, // [batch_size, 32, num_block] |
| - int *indices, // [batch_size, num_block] |
| - float *max_vals, // [batch_size, A_num_block * 32] |
| - float *max_vals_scatter, // [batch_size, 32, num_block] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long num_block |
| -) { |
| - |
| - long batch_idx = blockIdx.x; |
| - |
| - long thread_idx = threadIdx.x; |
| - long num_thread = blockDim.x; |
| - |
| - extern __shared__ float buffer[]; |
| - int *max_buffer = (int*)buffer; |
| - |
| - for (int i = 0; i < A_num_block * 32; i = i + num_thread) { |
| - int idx = i + thread_idx; |
| - if (idx < A_num_block * 32) { |
| - max_buffer[idx] = -1e8; |
| - } |
| - } |
| - __syncthreads(); |
| - |
| - int *indices_pt = &indices[batch_idx * num_block]; |
| - float *index_vals_pt = &index_vals[batch_idx * num_block * 32]; |
| - |
| - for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) { |
| - int idx = idx_start + thread_idx; |
| - int A_block_idx = indices_pt[idx % num_block] / B_num_block; |
| - atomicMax(&max_buffer[A_block_idx * 32 + idx / num_block], (int)(index_vals_pt[idx] * 1000)); |
| - } |
| - __syncthreads(); |
| - |
| - float *max_vals_pt = &max_vals[batch_idx * A_num_block * 32]; |
| - for (int i = 0; i < A_num_block * 32; i = i + num_thread) { |
| - int idx = i + thread_idx; |
| - if (idx < A_num_block * 32) { |
| - max_vals_pt[idx] = (float)max_buffer[idx] / 1000.; |
| - } |
| - } |
| - |
| - float *max_vals_scatter_pt = &max_vals_scatter[batch_idx * num_block * 32]; |
| - for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) { |
| - int idx = idx_start + thread_idx; |
| - int A_block_idx = indices_pt[idx % num_block] / B_num_block; |
| - max_vals_scatter_pt[idx] = (float)max_buffer[A_block_idx * 32 + idx / num_block] / 1000.; |
| - } |
| - |
| -} |
| - |
| -__global__ void mm_to_sparse_cuda_kernel( |
| - float *dense_A, // [batch_size, A_num_block, dim, 32] |
| - float *dense_B, // [batch_size, B_num_block, dim, 32] |
| - int *indices, // [batch_size, num_block] |
| - float *sparse_C, // [batch_size, num_block, 32, 32] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long dim, |
| - long num_block |
| -) { |
| - |
| - long batch_idx = blockIdx.y; |
| - long block_idx = blockIdx.x * blockDim.y + threadIdx.y; |
| - |
| - long thread_idx = threadIdx.x; |
| - |
| - __shared__ float buffer[4096]; |
| - float *A_buffer = &buffer[threadIdx.y * 1024]; // [2, 8, 32] |
| - float *B_buffer = &buffer[threadIdx.y * 1024 + 512]; // [2, 8, 32] |
| - |
| - long batch_idx__block_idx = batch_idx * num_block + block_idx; |
| - |
| - long AB_block_idx = indices[batch_idx__block_idx]; |
| - float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * dim * 32]; |
| - float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * dim * 32]; |
| - |
| - int reg_1_idx = thread_idx / 8; // [0000000011111111222222223333333344444444555555556666666677777777] |
| - int reg_2_idx = thread_idx % 8; // [0123456701234567012345670123456701234567012345670123456701234567] |
| - |
| - float reg_1[8]; |
| - float reg_2[8]; |
| - |
| - float reg_array[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - A_buffer[i * 64 + thread_idx] = dense_A_pt[i * 64 + thread_idx]; |
| - B_buffer[i * 64 + thread_idx] = dense_B_pt[i * 64 + thread_idx]; |
| - } |
| - |
| - __syncthreads(); |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - reg_1[i] = A_buffer[reg_1_idx * 4 + i]; |
| - reg_2[i] = B_buffer[reg_2_idx * 4 + i]; |
| - } |
| - |
| - for (int dim_stride = 1; dim_stride < (dim / 8); dim_stride++) { |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - A_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_A_pt[dim_stride * 256 + i * 64 + thread_idx]; |
| - B_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_B_pt[dim_stride * 256 + i * 64 + thread_idx]; |
| - } |
| - |
| - #pragma unroll |
| - for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) { |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_1_idx * 4 + i]; |
| - reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_2_idx * 4 + i]; |
| - } |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - #pragma unroll |
| - for (int j = 0; j < 4; j++) { |
| - reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j]; |
| - } |
| - } |
| - } |
| - |
| - __syncthreads(); |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - reg_1[i] = A_buffer[(dim_stride % 2) * 256 + reg_1_idx * 4 + i]; |
| - reg_2[i] = B_buffer[(dim_stride % 2) * 256 + reg_2_idx * 4 + i]; |
| - } |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - #pragma unroll |
| - for (int j = 0; j < 4; j++) { |
| - reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j]; |
| - } |
| - } |
| - |
| - } |
| - |
| - #pragma unroll |
| - for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) { |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[256 + mini_dim_idx * 32 + reg_1_idx * 4 + i]; |
| - reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[256 + mini_dim_idx * 32 + reg_2_idx * 4 + i]; |
| - } |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - #pragma unroll |
| - for (int j = 0; j < 4; j++) { |
| - reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j]; |
| - } |
| - } |
| - } |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - #pragma unroll |
| - for (int j = 0; j < 4; j++) { |
| - reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j]; |
| - } |
| - } |
| - __syncthreads(); |
| - |
| - float *C_buffer = &buffer[threadIdx.y * 1024]; // [32, 32] |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - #pragma unroll |
| - for (int j = 0; j < 4; j++) { |
| - C_buffer[(reg_2_idx * 4 + j) * 32 + reg_1_idx * 4 + i] = reg_array[i * 4 + j]; |
| - } |
| - } |
| - __syncthreads(); |
| - |
| - float *sparse_C_pt = &sparse_C[batch_idx__block_idx * 1024]; |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 16; i++) { |
| - sparse_C_pt[i * 64 + thread_idx] = C_buffer[i * 64 + thread_idx]; |
| - } |
| - |
| -} |
| - |
| -__global__ void sparse_dense_mm_cuda_kernel( |
| - float *sparse_A, // [batch_size, num_block, 32, 32] |
| - int *indices, // [batch_size, num_block] |
| - float *dense_B, // [batch_size, B_num_block, dim, 32] |
| - float *dense_C, // [batch_size, A_num_block, dim, 32] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long dim, |
| - long num_block |
| -) { |
| - |
| - long batch_idx = blockIdx.y; |
| - long block_idx = blockIdx.x * blockDim.y + threadIdx.y; |
| - |
| - long thread_idx = threadIdx.x; |
| - |
| - __shared__ float buffer[6144]; |
| - float *A_buffer = &buffer[threadIdx.y * 3072]; // [32, 32] |
| - float *B_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [32, 64] |
| - |
| - long batch_idx__block_idx = batch_idx * num_block + block_idx; |
| - |
| - float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024]; |
| - #pragma unroll |
| - for (int i = 0; i < 8; i++) { |
| - A_buffer[i * 128 + thread_idx] = sparse_A_pt[i * 128 + thread_idx]; |
| - } |
| - |
| - long AB_block_idx = indices[batch_idx__block_idx]; |
| - float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * 32 * dim]; |
| - float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32 * dim]; |
| - |
| - // [0000000011111111222222223333333344444444555555556666666677777777] |
| - // [0123456701234567012345670123456701234567012345670123456701234567] |
| - int reg_1_idx = thread_idx / 8; |
| - int reg_2_idx = thread_idx % 8; |
| - |
| - float reg_1[8]; |
| - float reg_2[8]; |
| - |
| - float reg_array[16]; |
| - |
| - for (int dim_stride = 0; dim_stride < dim; dim_stride = dim_stride + 64) { |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 16; i++) { |
| - B_buffer[i * 128 + thread_idx] = dense_B_pt[dim_stride * 32 + i * 128 + thread_idx]; |
| - } |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 16; i++) { |
| - reg_array[i] = 0; |
| - } |
| - |
| - __syncthreads(); |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - reg_1[i] = B_buffer[(reg_1_idx * 4 + i) * 32]; |
| - reg_2[i] = A_buffer[reg_2_idx * 4 + i]; |
| - } |
| - |
| - #pragma unroll |
| - for (int mini_dim_idx = 1; mini_dim_idx < 32; mini_dim_idx++) { |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - reg_1[(mini_dim_idx % 2) * 4 + i] = B_buffer[(reg_1_idx * 4 + i) * 32 + mini_dim_idx]; |
| - reg_2[(mini_dim_idx % 2) * 4 + i] = A_buffer[mini_dim_idx * 32 + reg_2_idx * 4 + i]; |
| - } |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - #pragma unroll |
| - for (int j = 0; j < 4; j++) { |
| - reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j]; |
| - } |
| - } |
| - } |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - #pragma unroll |
| - for (int j = 0; j < 4; j++) { |
| - reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j]; |
| - } |
| - } |
| - |
| - __syncthreads(); |
| - |
| - float *C_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [64, 32] |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 4; i++) { |
| - #pragma unroll |
| - for (int j = 0; j < 4; j++) { |
| - C_buffer[(reg_1_idx * 4 + i) * 32 + reg_2_idx * 4 + j] = reg_array[i * 4 + j]; |
| - } |
| - } |
| - __syncthreads(); |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 16; i++) { |
| - atomicAdd(&dense_C_pt[dim_stride * 32 + i * 128 + thread_idx], C_buffer[i * 128 + thread_idx]); |
| - } |
| - __syncthreads(); |
| - |
| - } |
| - |
| -} |
| - |
| - |
| -__global__ void reduce_sum_cuda_kernel( |
| - float *sparse_A, // [batch_size, num_block, 32, 32] |
| - int *indices, // [batch_size, num_block] |
| - float *dense_C, // [batch_size, A_num_block, 32] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long num_block |
| -) { |
| - |
| - long batch_idx = blockIdx.y; |
| - long block_idx = blockIdx.x * blockDim.y + threadIdx.y; |
| - |
| - long thread_idx = threadIdx.x; |
| - |
| - long batch_idx__block_idx = batch_idx * num_block + block_idx; |
| - |
| - long AB_block_idx = indices[batch_idx__block_idx]; |
| - float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024]; |
| - |
| - float reg_array[16]; |
| - float value = 0; |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 8; i++) { |
| - reg_array[i] = sparse_A_pt[i * 32 + thread_idx]; |
| - } |
| - #pragma unroll |
| - for (int stride = 8; stride < 32; stride = stride + 8) { |
| - #pragma unroll |
| - for (int i = 0; i < 8; i++) { |
| - reg_array[(stride + i) % 16] = sparse_A_pt[(stride + i) * 32 + thread_idx]; |
| - } |
| - #pragma unroll |
| - for (int i = 0; i < 8; i++) { |
| - value = value + reg_array[(stride - 8 + i) % 16]; |
| - } |
| - } |
| - #pragma unroll |
| - for (int i = 0; i < 8; i++) { |
| - value = value + reg_array[8 + i]; |
| - } |
| - |
| - float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32]; |
| - |
| - atomicAdd(&dense_C_pt[thread_idx], value); |
| - |
| -} |
| - |
| -__global__ void scatter_cuda_kernel( |
| - float *dense_A, // [batch_size, A_num_block, 32] |
| - int *indices, // [batch_size, num_block] |
| - float *sparse_C, // [batch_size, num_block, 32, 32] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long num_block |
| -) { |
| - |
| - long batch_idx = blockIdx.y; |
| - long block_idx = blockIdx.x * blockDim.y + threadIdx.y; |
| - |
| - long thread_idx = threadIdx.x; |
| - |
| - long batch_idx__block_idx = batch_idx * num_block + block_idx; |
| - |
| - long AB_block_idx = indices[batch_idx__block_idx]; |
| - float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32]; |
| - float *sparse_C_pt = &sparse_C[(batch_idx * num_block + block_idx) * 1024]; |
| - |
| - float value = dense_A_pt[thread_idx]; |
| - |
| - #pragma unroll |
| - for (int i = 0; i < 32; i++) { |
| - sparse_C_pt[i * 32 + thread_idx] = value; |
| - } |
| - |
| -} |
| |
| deleted file mode 100644 |
| |
| |
| |
| @@ -1,59 +0,0 @@ |
| - |
| -#define WARP_SIZE 32 |
| -#define FULL_MASK 0xffffffff |
| -#define OPTIMAL_THREADS 256 |
| - |
| -__global__ void index_max_cuda_kernel( |
| - float *index_vals, // [batch_size, 32, num_block] |
| - int *indices, // [batch_size, num_block] |
| - float *max_vals, // [batch_size, A_num_block * 32] |
| - float *max_vals_scatter, // [batch_size, 32, num_block] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long num_block |
| -); |
| - |
| -__global__ void mm_to_sparse_cuda_kernel( |
| - float *dense_A, // [batch_size, A_num_block, dim, 32] |
| - float *dense_B, // [batch_size, B_num_block, dim, 32] |
| - int *indices, // [batch_size, num_block] |
| - float *sparse_C, // [batch_size, num_block, 32, 32] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long dim, |
| - long num_block |
| -); |
| - |
| -__global__ void sparse_dense_mm_cuda_kernel( |
| - float *sparse_A, // [batch_size, num_block, 32, 32] |
| - int *indices, // [batch_size, num_block] |
| - float *dense_B, // [batch_size, B_num_block, dim, 32] |
| - float *dense_C, // [batch_size, A_num_block, dim, 32] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long dim, |
| - long num_block |
| -); |
| - |
| -__global__ void reduce_sum_cuda_kernel( |
| - float *sparse_A, // [batch_size, num_block, 32, 32] |
| - int *indices, // [batch_size, num_block] |
| - float *dense_C, // [batch_size, A_num_block, 32] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long num_block |
| -); |
| - |
| -__global__ void scatter_cuda_kernel( |
| - float *dense_A, // [batch_size, A_num_block, 32] |
| - int *indices, // [batch_size, num_block] |
| - float *sparse_C, // [batch_size, num_block, 32, 32] |
| - long batch_size, |
| - long A_num_block, |
| - long B_num_block, |
| - long num_block |
| -); |
| |
| deleted file mode 100644 |
| |
| |
| |
| @@ -1,154 +0,0 @@ |
| -#include <torch/extension.h> |
| -#include <ATen/ATen.h> |
| -#include "cuda_launch.h" |
| -#include "cuda_kernel.h" |
| -#include <vector> |
| - |
| -////////////////////////////////////////////////////////////////////////////////////////////////// |
| -////////////////////////////////////////////////////////////////////////////////////////////////// |
| - |
| -std::vector<at::Tensor> index_max_kernel( |
| - at::Tensor index_vals, // [batch_size, 32, num_block] |
| - at::Tensor indices, // [batch_size, num_block], |
| - int A_num_block, |
| - int B_num_block |
| -) { |
| - int batch_size = indices.size(0); |
| - int num_block = indices.size(1); |
| - |
| - at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options()); |
| - at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options()); |
| - |
| - dim3 threads(256); |
| - dim3 blocks(batch_size); |
| - int shared_mem = A_num_block * 32 * sizeof(float); |
| - |
| - index_max_cuda_kernel<<<blocks, threads, shared_mem>>>( |
| - index_vals.data_ptr<float>(), |
| - indices.data_ptr<int>(), |
| - max_vals.data_ptr<float>(), |
| - max_vals_scatter.data_ptr<float>(), |
| - batch_size, |
| - A_num_block, |
| - B_num_block, |
| - num_block |
| - ); |
| - |
| - return {max_vals, max_vals_scatter}; |
| -} |
| - |
| -at::Tensor mm_to_sparse_kernel( |
| - at::Tensor dense_A, // [batch_size, A_num_block, dim, 32] |
| - at::Tensor dense_B, // [batch_size, B_num_block, dim, 32] |
| - at::Tensor indices // [batch_size, num_block] |
| -) { |
| - int batch_size = dense_A.size(0); |
| - int A_num_block = dense_A.size(1); |
| - int B_num_block = dense_B.size(1); |
| - int dim = dense_A.size(2); |
| - int num_block = indices.size(1); |
| - |
| - at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options()); |
| - |
| - dim3 threads(64, 4); |
| - dim3 blocks(num_block / 4, batch_size); |
| - |
| - mm_to_sparse_cuda_kernel<<<blocks, threads>>>( |
| - dense_A.data_ptr<float>(), |
| - dense_B.data_ptr<float>(), |
| - indices.data_ptr<int>(), |
| - sparse_C.data_ptr<float>(), |
| - batch_size, |
| - A_num_block, |
| - B_num_block, |
| - dim, |
| - num_block |
| - ); |
| - |
| - return sparse_C; |
| -} |
| - |
| -at::Tensor sparse_dense_mm_kernel( |
| - at::Tensor sparse_A, // [batch_size, num_block, 32, 32] |
| - at::Tensor indices, // [batch_size, num_block] |
| - at::Tensor dense_B, // [batch_size, B_num_block, dim, 32] |
| - int A_num_block |
| -) { |
| - int batch_size = sparse_A.size(0); |
| - int num_block = sparse_A.size(1); |
| - int B_num_block = dense_B.size(1); |
| - int dim = dense_B.size(2); |
| - |
| - at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options()); |
| - |
| - dim3 threads(128, 2); |
| - dim3 blocks(num_block / 2, batch_size); |
| - |
| - sparse_dense_mm_cuda_kernel<<<blocks, threads>>>( |
| - sparse_A.data_ptr<float>(), |
| - indices.data_ptr<int>(), |
| - dense_B.data_ptr<float>(), |
| - dense_C.data_ptr<float>(), |
| - batch_size, |
| - A_num_block, |
| - B_num_block, |
| - dim, |
| - num_block |
| - ); |
| - |
| - return dense_C; |
| -} |
| - |
| -at::Tensor reduce_sum_kernel( |
| - at::Tensor sparse_A, // [batch_size, num_block, 32, 32] |
| - at::Tensor indices, // [batch_size, num_block] |
| - int A_num_block, |
| - int B_num_block |
| -) { |
| - int batch_size = sparse_A.size(0); |
| - int num_block = sparse_A.size(1); |
| - |
| - at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options()); |
| - |
| - dim3 threads(32, 4); |
| - dim3 blocks(num_block / 4, batch_size); |
| - |
| - reduce_sum_cuda_kernel<<<blocks, threads>>>( |
| - sparse_A.data_ptr<float>(), |
| - indices.data_ptr<int>(), |
| - dense_C.data_ptr<float>(), |
| - batch_size, |
| - A_num_block, |
| - B_num_block, |
| - num_block |
| - ); |
| - |
| - return dense_C; |
| -} |
| - |
| -at::Tensor scatter_kernel( |
| - at::Tensor dense_A, // [batch_size, A_num_block, 32] |
| - at::Tensor indices, // [batch_size, num_block] |
| - int B_num_block |
| -) { |
| - int batch_size = dense_A.size(0); |
| - int A_num_block = dense_A.size(1); |
| - int num_block = indices.size(1); |
| - |
| - at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options()); |
| - |
| - dim3 threads(32, 4); |
| - dim3 blocks(num_block / 4, batch_size); |
| - |
| - scatter_cuda_kernel<<<blocks, threads>>>( |
| - dense_A.data_ptr<float>(), |
| - indices.data_ptr<int>(), |
| - sparse_C.data_ptr<float>(), |
| - batch_size, |
| - A_num_block, |
| - B_num_block, |
| - num_block |
| - ); |
| - |
| - return sparse_C; |
| -} |
| |
| deleted file mode 100644 |
| |
| |
| |
| @@ -1,39 +0,0 @@ |
| -#include <torch/extension.h> |
| -#include <ATen/ATen.h> |
| -#include <vector> |
| - |
| -#define min(a, b) ((a)<(b)?(a):(b)) |
| -#define max(a, b) ((a)>(b)?(a):(b)) |
| - |
| -std::vector<at::Tensor> index_max_kernel( |
| - at::Tensor index_vals, |
| - at::Tensor indices, |
| - int A_num_block, |
| - int B_num_block |
| -); |
| - |
| -at::Tensor mm_to_sparse_kernel( |
| - at::Tensor dense_A, |
| - at::Tensor dense_B, |
| - at::Tensor indices |
| -); |
| - |
| -at::Tensor sparse_dense_mm_kernel( |
| - at::Tensor sparse_A, |
| - at::Tensor indices, |
| - at::Tensor dense_B, |
| - int A_num_block |
| -); |
| - |
| -at::Tensor reduce_sum_kernel( |
| - at::Tensor sparse_A, |
| - at::Tensor indices, |
| - int A_num_block, |
| - int B_num_block |
| -); |
| - |
| -at::Tensor scatter_kernel( |
| - at::Tensor dense_A, |
| - at::Tensor indices, |
| - int B_num_block |
| -); |
| |
| deleted file mode 100644 |
| |
| |
| |
| @@ -1,78 +0,0 @@ |
| -#include <torch/extension.h> |
| -#include <ATen/ATen.h> |
| -#include "cuda_launch.h" |
| -#include <vector> |
| - |
| -std::vector<at::Tensor> index_max( |
| - at::Tensor index_vals, |
| - at::Tensor indices, |
| - int A_num_block, |
| - int B_num_block |
| -) { |
| - return index_max_kernel( |
| - index_vals, |
| - indices, |
| - A_num_block, |
| - B_num_block |
| - ); |
| -} |
| - |
| -at::Tensor mm_to_sparse( |
| - at::Tensor dense_A, |
| - at::Tensor dense_B, |
| - at::Tensor indices |
| -) { |
| - return mm_to_sparse_kernel( |
| - dense_A, |
| - dense_B, |
| - indices |
| - ); |
| -} |
| - |
| -at::Tensor sparse_dense_mm( |
| - at::Tensor sparse_A, |
| - at::Tensor indices, |
| - at::Tensor dense_B, |
| - int A_num_block |
| -) { |
| - return sparse_dense_mm_kernel( |
| - sparse_A, |
| - indices, |
| - dense_B, |
| - A_num_block |
| - ); |
| -} |
| - |
| -at::Tensor reduce_sum( |
| - at::Tensor sparse_A, |
| - at::Tensor indices, |
| - int A_num_block, |
| - int B_num_block |
| -) { |
| - return reduce_sum_kernel( |
| - sparse_A, |
| - indices, |
| - A_num_block, |
| - B_num_block |
| - ); |
| -} |
| - |
| -at::Tensor scatter( |
| - at::Tensor dense_A, |
| - at::Tensor indices, |
| - int B_num_block |
| -) { |
| - return scatter_kernel( |
| - dense_A, |
| - indices, |
| - B_num_block |
| - ); |
| -} |
| - |
| -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| - m.def("index_max", &index_max, "index_max (CUDA)"); |
| - m.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)"); |
| - m.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)"); |
| - m.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)"); |
| - m.def("scatter", &scatter, "scatter (CUDA)"); |
| -} |
| |
| |
| |
| |
| @@ -15,13 +15,11 @@ |
| """PyTorch MRA model.""" |
| |
| import math |
| -from pathlib import Path |
| from typing import Optional, Union |
| |
| import torch |
| from torch import nn |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| -from torch.utils.cpp_extension import load |
| |
| from ...activations import ACT2FN |
| from ...modeling_layers import GradientCheckpointingLayer |
| @@ -35,7 +33,14 @@ |
| ) |
| from ...modeling_utils import PreTrainedModel |
| from ...pytorch_utils import apply_chunking_to_forward |
| -from ...utils import auto_docstring, is_cuda_platform, is_ninja_available, is_torch_cuda_available, logging |
| +from ...utils import ( |
| + auto_docstring, |
| + is_cuda_platform, |
| + is_kernels_available, |
| + is_ninja_available, |
| + is_torch_cuda_available, |
| + logging, |
| +) |
| from .configuration_mra import MraConfig |
| |
| |
| @@ -46,14 +51,11 @@ |
| |
| def load_cuda_kernels(): |
| global mra_cuda_kernel |
| - src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra" |
| - |
| - def append_root(files): |
| - return [src_folder / file for file in files] |
| - |
| - src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"]) |
| + if not is_kernels_available(): |
| + raise ImportError("kernels is not installed, please install it with `pip install kernels`") |
| + from kernels import get_kernel |
| |
| - mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True) |
| + mra_cuda_kernel = get_kernel("kernels-community/mra") |
| |
| |
| def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block): |
|
|