harness / diffs /41507.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/kernels/mra/cuda_kernel.cu b/src/transformers/kernels/mra/cuda_kernel.cu
deleted file mode 100644
index 87ed89052873..000000000000
--- a/src/transformers/kernels/mra/cuda_kernel.cu
+++ /dev/null
@@ -1,383 +0,0 @@
-#include "cuda_kernel.h"
-
-//////////////////////////////////////////////////////////////////////////////////////////////////
-//////////////////////////////////////////////////////////////////////////////////////////////////
-
-__global__ void index_max_cuda_kernel(
- float *index_vals, // [batch_size, 32, num_block]
- int *indices, // [batch_size, num_block]
- float *max_vals, // [batch_size, A_num_block * 32]
- float *max_vals_scatter, // [batch_size, 32, num_block]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long num_block
-) {
-
- long batch_idx = blockIdx.x;
-
- long thread_idx = threadIdx.x;
- long num_thread = blockDim.x;
-
- extern __shared__ float buffer[];
- int *max_buffer = (int*)buffer;
-
- for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
- int idx = i + thread_idx;
- if (idx < A_num_block * 32) {
- max_buffer[idx] = -1e8;
- }
- }
- __syncthreads();
-
- int *indices_pt = &indices[batch_idx * num_block];
- float *index_vals_pt = &index_vals[batch_idx * num_block * 32];
-
- for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
- int idx = idx_start + thread_idx;
- int A_block_idx = indices_pt[idx % num_block] / B_num_block;
- atomicMax(&max_buffer[A_block_idx * 32 + idx / num_block], (int)(index_vals_pt[idx] * 1000));
- }
- __syncthreads();
-
- float *max_vals_pt = &max_vals[batch_idx * A_num_block * 32];
- for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
- int idx = i + thread_idx;
- if (idx < A_num_block * 32) {
- max_vals_pt[idx] = (float)max_buffer[idx] / 1000.;
- }
- }
-
- float *max_vals_scatter_pt = &max_vals_scatter[batch_idx * num_block * 32];
- for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
- int idx = idx_start + thread_idx;
- int A_block_idx = indices_pt[idx % num_block] / B_num_block;
- max_vals_scatter_pt[idx] = (float)max_buffer[A_block_idx * 32 + idx / num_block] / 1000.;
- }
-
-}
-
-__global__ void mm_to_sparse_cuda_kernel(
- float *dense_A, // [batch_size, A_num_block, dim, 32]
- float *dense_B, // [batch_size, B_num_block, dim, 32]
- int *indices, // [batch_size, num_block]
- float *sparse_C, // [batch_size, num_block, 32, 32]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long dim,
- long num_block
-) {
-
- long batch_idx = blockIdx.y;
- long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
-
- long thread_idx = threadIdx.x;
-
- __shared__ float buffer[4096];
- float *A_buffer = &buffer[threadIdx.y * 1024]; // [2, 8, 32]
- float *B_buffer = &buffer[threadIdx.y * 1024 + 512]; // [2, 8, 32]
-
- long batch_idx__block_idx = batch_idx * num_block + block_idx;
-
- long AB_block_idx = indices[batch_idx__block_idx];
- float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * dim * 32];
- float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * dim * 32];
-
- int reg_1_idx = thread_idx / 8; // [0000000011111111222222223333333344444444555555556666666677777777]
- int reg_2_idx = thread_idx % 8; // [0123456701234567012345670123456701234567012345670123456701234567]
-
- float reg_1[8];
- float reg_2[8];
-
- float reg_array[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- A_buffer[i * 64 + thread_idx] = dense_A_pt[i * 64 + thread_idx];
- B_buffer[i * 64 + thread_idx] = dense_B_pt[i * 64 + thread_idx];
- }
-
- __syncthreads();
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- reg_1[i] = A_buffer[reg_1_idx * 4 + i];
- reg_2[i] = B_buffer[reg_2_idx * 4 + i];
- }
-
- for (int dim_stride = 1; dim_stride < (dim / 8); dim_stride++) {
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- A_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_A_pt[dim_stride * 256 + i * 64 + thread_idx];
- B_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_B_pt[dim_stride * 256 + i * 64 + thread_idx];
- }
-
- #pragma unroll
- for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
- reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
- }
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
- }
- }
- }
-
- __syncthreads();
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- reg_1[i] = A_buffer[(dim_stride % 2) * 256 + reg_1_idx * 4 + i];
- reg_2[i] = B_buffer[(dim_stride % 2) * 256 + reg_2_idx * 4 + i];
- }
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
- }
- }
-
- }
-
- #pragma unroll
- for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
- reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
- }
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
- }
- }
- }
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
- }
- }
- __syncthreads();
-
- float *C_buffer = &buffer[threadIdx.y * 1024]; // [32, 32]
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- C_buffer[(reg_2_idx * 4 + j) * 32 + reg_1_idx * 4 + i] = reg_array[i * 4 + j];
- }
- }
- __syncthreads();
-
- float *sparse_C_pt = &sparse_C[batch_idx__block_idx * 1024];
-
- #pragma unroll
- for (int i = 0; i < 16; i++) {
- sparse_C_pt[i * 64 + thread_idx] = C_buffer[i * 64 + thread_idx];
- }
-
-}
-
-__global__ void sparse_dense_mm_cuda_kernel(
- float *sparse_A, // [batch_size, num_block, 32, 32]
- int *indices, // [batch_size, num_block]
- float *dense_B, // [batch_size, B_num_block, dim, 32]
- float *dense_C, // [batch_size, A_num_block, dim, 32]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long dim,
- long num_block
-) {
-
- long batch_idx = blockIdx.y;
- long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
-
- long thread_idx = threadIdx.x;
-
- __shared__ float buffer[6144];
- float *A_buffer = &buffer[threadIdx.y * 3072]; // [32, 32]
- float *B_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [32, 64]
-
- long batch_idx__block_idx = batch_idx * num_block + block_idx;
-
- float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- A_buffer[i * 128 + thread_idx] = sparse_A_pt[i * 128 + thread_idx];
- }
-
- long AB_block_idx = indices[batch_idx__block_idx];
- float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * 32 * dim];
- float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32 * dim];
-
- // [0000000011111111222222223333333344444444555555556666666677777777]
- // [0123456701234567012345670123456701234567012345670123456701234567]
- int reg_1_idx = thread_idx / 8;
- int reg_2_idx = thread_idx % 8;
-
- float reg_1[8];
- float reg_2[8];
-
- float reg_array[16];
-
- for (int dim_stride = 0; dim_stride < dim; dim_stride = dim_stride + 64) {
-
- #pragma unroll
- for (int i = 0; i < 16; i++) {
- B_buffer[i * 128 + thread_idx] = dense_B_pt[dim_stride * 32 + i * 128 + thread_idx];
- }
-
- #pragma unroll
- for (int i = 0; i < 16; i++) {
- reg_array[i] = 0;
- }
-
- __syncthreads();
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- reg_1[i] = B_buffer[(reg_1_idx * 4 + i) * 32];
- reg_2[i] = A_buffer[reg_2_idx * 4 + i];
- }
-
- #pragma unroll
- for (int mini_dim_idx = 1; mini_dim_idx < 32; mini_dim_idx++) {
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- reg_1[(mini_dim_idx % 2) * 4 + i] = B_buffer[(reg_1_idx * 4 + i) * 32 + mini_dim_idx];
- reg_2[(mini_dim_idx % 2) * 4 + i] = A_buffer[mini_dim_idx * 32 + reg_2_idx * 4 + i];
- }
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
- }
- }
- }
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
- }
- }
-
- __syncthreads();
-
- float *C_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [64, 32]
-
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- #pragma unroll
- for (int j = 0; j < 4; j++) {
- C_buffer[(reg_1_idx * 4 + i) * 32 + reg_2_idx * 4 + j] = reg_array[i * 4 + j];
- }
- }
- __syncthreads();
-
- #pragma unroll
- for (int i = 0; i < 16; i++) {
- atomicAdd(&dense_C_pt[dim_stride * 32 + i * 128 + thread_idx], C_buffer[i * 128 + thread_idx]);
- }
- __syncthreads();
-
- }
-
-}
-
-
-__global__ void reduce_sum_cuda_kernel(
- float *sparse_A, // [batch_size, num_block, 32, 32]
- int *indices, // [batch_size, num_block]
- float *dense_C, // [batch_size, A_num_block, 32]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long num_block
-) {
-
- long batch_idx = blockIdx.y;
- long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
-
- long thread_idx = threadIdx.x;
-
- long batch_idx__block_idx = batch_idx * num_block + block_idx;
-
- long AB_block_idx = indices[batch_idx__block_idx];
- float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
-
- float reg_array[16];
- float value = 0;
-
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- reg_array[i] = sparse_A_pt[i * 32 + thread_idx];
- }
- #pragma unroll
- for (int stride = 8; stride < 32; stride = stride + 8) {
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- reg_array[(stride + i) % 16] = sparse_A_pt[(stride + i) * 32 + thread_idx];
- }
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- value = value + reg_array[(stride - 8 + i) % 16];
- }
- }
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- value = value + reg_array[8 + i];
- }
-
- float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
-
- atomicAdd(&dense_C_pt[thread_idx], value);
-
-}
-
-__global__ void scatter_cuda_kernel(
- float *dense_A, // [batch_size, A_num_block, 32]
- int *indices, // [batch_size, num_block]
- float *sparse_C, // [batch_size, num_block, 32, 32]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long num_block
-) {
-
- long batch_idx = blockIdx.y;
- long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
-
- long thread_idx = threadIdx.x;
-
- long batch_idx__block_idx = batch_idx * num_block + block_idx;
-
- long AB_block_idx = indices[batch_idx__block_idx];
- float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
- float *sparse_C_pt = &sparse_C[(batch_idx * num_block + block_idx) * 1024];
-
- float value = dense_A_pt[thread_idx];
-
- #pragma unroll
- for (int i = 0; i < 32; i++) {
- sparse_C_pt[i * 32 + thread_idx] = value;
- }
-
-}
diff --git a/src/transformers/kernels/mra/cuda_kernel.h b/src/transformers/kernels/mra/cuda_kernel.h
deleted file mode 100644
index a95b46f7d159..000000000000
--- a/src/transformers/kernels/mra/cuda_kernel.h
+++ /dev/null
@@ -1,59 +0,0 @@
-
-#define WARP_SIZE 32
-#define FULL_MASK 0xffffffff
-#define OPTIMAL_THREADS 256
-
-__global__ void index_max_cuda_kernel(
- float *index_vals, // [batch_size, 32, num_block]
- int *indices, // [batch_size, num_block]
- float *max_vals, // [batch_size, A_num_block * 32]
- float *max_vals_scatter, // [batch_size, 32, num_block]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long num_block
-);
-
-__global__ void mm_to_sparse_cuda_kernel(
- float *dense_A, // [batch_size, A_num_block, dim, 32]
- float *dense_B, // [batch_size, B_num_block, dim, 32]
- int *indices, // [batch_size, num_block]
- float *sparse_C, // [batch_size, num_block, 32, 32]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long dim,
- long num_block
-);
-
-__global__ void sparse_dense_mm_cuda_kernel(
- float *sparse_A, // [batch_size, num_block, 32, 32]
- int *indices, // [batch_size, num_block]
- float *dense_B, // [batch_size, B_num_block, dim, 32]
- float *dense_C, // [batch_size, A_num_block, dim, 32]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long dim,
- long num_block
-);
-
-__global__ void reduce_sum_cuda_kernel(
- float *sparse_A, // [batch_size, num_block, 32, 32]
- int *indices, // [batch_size, num_block]
- float *dense_C, // [batch_size, A_num_block, 32]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long num_block
-);
-
-__global__ void scatter_cuda_kernel(
- float *dense_A, // [batch_size, A_num_block, 32]
- int *indices, // [batch_size, num_block]
- float *sparse_C, // [batch_size, num_block, 32, 32]
- long batch_size,
- long A_num_block,
- long B_num_block,
- long num_block
-);
diff --git a/src/transformers/kernels/mra/cuda_launch.cu b/src/transformers/kernels/mra/cuda_launch.cu
deleted file mode 100644
index ba2a0cacfe61..000000000000
--- a/src/transformers/kernels/mra/cuda_launch.cu
+++ /dev/null
@@ -1,154 +0,0 @@
-#include <torch/extension.h>
-#include <ATen/ATen.h>
-#include "cuda_launch.h"
-#include "cuda_kernel.h"
-#include <vector>
-
-//////////////////////////////////////////////////////////////////////////////////////////////////
-//////////////////////////////////////////////////////////////////////////////////////////////////
-
-std::vector<at::Tensor> index_max_kernel(
- at::Tensor index_vals, // [batch_size, 32, num_block]
- at::Tensor indices, // [batch_size, num_block],
- int A_num_block,
- int B_num_block
-) {
- int batch_size = indices.size(0);
- int num_block = indices.size(1);
-
- at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options());
- at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options());
-
- dim3 threads(256);
- dim3 blocks(batch_size);
- int shared_mem = A_num_block * 32 * sizeof(float);
-
- index_max_cuda_kernel<<<blocks, threads, shared_mem>>>(
- index_vals.data_ptr<float>(),
- indices.data_ptr<int>(),
- max_vals.data_ptr<float>(),
- max_vals_scatter.data_ptr<float>(),
- batch_size,
- A_num_block,
- B_num_block,
- num_block
- );
-
- return {max_vals, max_vals_scatter};
-}
-
-at::Tensor mm_to_sparse_kernel(
- at::Tensor dense_A, // [batch_size, A_num_block, dim, 32]
- at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
- at::Tensor indices // [batch_size, num_block]
-) {
- int batch_size = dense_A.size(0);
- int A_num_block = dense_A.size(1);
- int B_num_block = dense_B.size(1);
- int dim = dense_A.size(2);
- int num_block = indices.size(1);
-
- at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
-
- dim3 threads(64, 4);
- dim3 blocks(num_block / 4, batch_size);
-
- mm_to_sparse_cuda_kernel<<<blocks, threads>>>(
- dense_A.data_ptr<float>(),
- dense_B.data_ptr<float>(),
- indices.data_ptr<int>(),
- sparse_C.data_ptr<float>(),
- batch_size,
- A_num_block,
- B_num_block,
- dim,
- num_block
- );
-
- return sparse_C;
-}
-
-at::Tensor sparse_dense_mm_kernel(
- at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
- at::Tensor indices, // [batch_size, num_block]
- at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
- int A_num_block
-) {
- int batch_size = sparse_A.size(0);
- int num_block = sparse_A.size(1);
- int B_num_block = dense_B.size(1);
- int dim = dense_B.size(2);
-
- at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options());
-
- dim3 threads(128, 2);
- dim3 blocks(num_block / 2, batch_size);
-
- sparse_dense_mm_cuda_kernel<<<blocks, threads>>>(
- sparse_A.data_ptr<float>(),
- indices.data_ptr<int>(),
- dense_B.data_ptr<float>(),
- dense_C.data_ptr<float>(),
- batch_size,
- A_num_block,
- B_num_block,
- dim,
- num_block
- );
-
- return dense_C;
-}
-
-at::Tensor reduce_sum_kernel(
- at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
- at::Tensor indices, // [batch_size, num_block]
- int A_num_block,
- int B_num_block
-) {
- int batch_size = sparse_A.size(0);
- int num_block = sparse_A.size(1);
-
- at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options());
-
- dim3 threads(32, 4);
- dim3 blocks(num_block / 4, batch_size);
-
- reduce_sum_cuda_kernel<<<blocks, threads>>>(
- sparse_A.data_ptr<float>(),
- indices.data_ptr<int>(),
- dense_C.data_ptr<float>(),
- batch_size,
- A_num_block,
- B_num_block,
- num_block
- );
-
- return dense_C;
-}
-
-at::Tensor scatter_kernel(
- at::Tensor dense_A, // [batch_size, A_num_block, 32]
- at::Tensor indices, // [batch_size, num_block]
- int B_num_block
-) {
- int batch_size = dense_A.size(0);
- int A_num_block = dense_A.size(1);
- int num_block = indices.size(1);
-
- at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
-
- dim3 threads(32, 4);
- dim3 blocks(num_block / 4, batch_size);
-
- scatter_cuda_kernel<<<blocks, threads>>>(
- dense_A.data_ptr<float>(),
- indices.data_ptr<int>(),
- sparse_C.data_ptr<float>(),
- batch_size,
- A_num_block,
- B_num_block,
- num_block
- );
-
- return sparse_C;
-}
diff --git a/src/transformers/kernels/mra/cuda_launch.h b/src/transformers/kernels/mra/cuda_launch.h
deleted file mode 100644
index 0200140ee337..000000000000
--- a/src/transformers/kernels/mra/cuda_launch.h
+++ /dev/null
@@ -1,39 +0,0 @@
-#include <torch/extension.h>
-#include <ATen/ATen.h>
-#include <vector>
-
-#define min(a, b) ((a)<(b)?(a):(b))
-#define max(a, b) ((a)>(b)?(a):(b))
-
-std::vector<at::Tensor> index_max_kernel(
- at::Tensor index_vals,
- at::Tensor indices,
- int A_num_block,
- int B_num_block
-);
-
-at::Tensor mm_to_sparse_kernel(
- at::Tensor dense_A,
- at::Tensor dense_B,
- at::Tensor indices
-);
-
-at::Tensor sparse_dense_mm_kernel(
- at::Tensor sparse_A,
- at::Tensor indices,
- at::Tensor dense_B,
- int A_num_block
-);
-
-at::Tensor reduce_sum_kernel(
- at::Tensor sparse_A,
- at::Tensor indices,
- int A_num_block,
- int B_num_block
-);
-
-at::Tensor scatter_kernel(
- at::Tensor dense_A,
- at::Tensor indices,
- int B_num_block
-);
diff --git a/src/transformers/kernels/mra/torch_extension.cpp b/src/transformers/kernels/mra/torch_extension.cpp
deleted file mode 100644
index 60c9262b7792..000000000000
--- a/src/transformers/kernels/mra/torch_extension.cpp
+++ /dev/null
@@ -1,78 +0,0 @@
-#include <torch/extension.h>
-#include <ATen/ATen.h>
-#include "cuda_launch.h"
-#include <vector>
-
-std::vector<at::Tensor> index_max(
- at::Tensor index_vals,
- at::Tensor indices,
- int A_num_block,
- int B_num_block
-) {
- return index_max_kernel(
- index_vals,
- indices,
- A_num_block,
- B_num_block
- );
-}
-
-at::Tensor mm_to_sparse(
- at::Tensor dense_A,
- at::Tensor dense_B,
- at::Tensor indices
-) {
- return mm_to_sparse_kernel(
- dense_A,
- dense_B,
- indices
- );
-}
-
-at::Tensor sparse_dense_mm(
- at::Tensor sparse_A,
- at::Tensor indices,
- at::Tensor dense_B,
- int A_num_block
-) {
- return sparse_dense_mm_kernel(
- sparse_A,
- indices,
- dense_B,
- A_num_block
- );
-}
-
-at::Tensor reduce_sum(
- at::Tensor sparse_A,
- at::Tensor indices,
- int A_num_block,
- int B_num_block
-) {
- return reduce_sum_kernel(
- sparse_A,
- indices,
- A_num_block,
- B_num_block
- );
-}
-
-at::Tensor scatter(
- at::Tensor dense_A,
- at::Tensor indices,
- int B_num_block
-) {
- return scatter_kernel(
- dense_A,
- indices,
- B_num_block
- );
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("index_max", &index_max, "index_max (CUDA)");
- m.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)");
- m.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)");
- m.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)");
- m.def("scatter", &scatter, "scatter (CUDA)");
-}
diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py
index c80e362d6b93..478d66781851 100644
--- a/src/transformers/models/mra/modeling_mra.py
+++ b/src/transformers/models/mra/modeling_mra.py
@@ -15,13 +15,11 @@
"""PyTorch MRA model."""
import math
-from pathlib import Path
from typing import Optional, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from torch.utils.cpp_extension import load
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
@@ -35,7 +33,14 @@
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
-from ...utils import auto_docstring, is_cuda_platform, is_ninja_available, is_torch_cuda_available, logging
+from ...utils import (
+ auto_docstring,
+ is_cuda_platform,
+ is_kernels_available,
+ is_ninja_available,
+ is_torch_cuda_available,
+ logging,
+)
from .configuration_mra import MraConfig
@@ -46,14 +51,11 @@
def load_cuda_kernels():
global mra_cuda_kernel
- src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra"
-
- def append_root(files):
- return [src_folder / file for file in files]
-
- src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"])
+ if not is_kernels_available():
+ raise ImportError("kernels is not installed, please install it with `pip install kernels`")
+ from kernels import get_kernel
- mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True)
+ mra_cuda_kernel = get_kernel("kernels-community/mra")
def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):