#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>
|
|
| |
| |
// Signed dot product of two packed words, treating set and clear bits as
// opposite-sign (+1 / -1) values.
//
// Only the low-order `bits` bits of each word take part. bits <= 0 yields 0;
// bits >= 32 uses the whole word. Positions where the two words agree add +1 and
// positions where they differ add -1. The result is therefore
//   bits - 2 * popcount((a ^ b) & mask),
// which equals 4*popc(a&b) - 2*popc(a) - 2*popc(b) + bits on the masked words.
// NOTE: `bits` is added as given, not clamped, so callers should pass
// bits <= 32.
__device__ __forceinline__ int32_t word_dot_masked(uint32_t a, uint32_t b, int bits) {
  if (bits > 0) {
    const uint32_t valid =
        (bits < 32) ? ((1u << static_cast<unsigned int>(bits)) - 1u) : ~0u;
    const int32_t mismatches = static_cast<int32_t>(__popc((a ^ b) & valid));
    return bits - 2 * mismatches;
  }
  return 0;
}
|
|
// Packed binary GEMM: C[m][k] = sum over packed words p of
// word_dot_masked(A[m][p], B[p][k], valid_bits(p)).
//
// Layout (from the indexing below):
//   A : dense row-major [M][L] array of 32-bit words
//   B : dense row-major [L][K] array of 32-bit words
//   C : dense row-major [M][K] float output
// `original_n` is the number of meaningful bit positions along the packed
// dimension. Word p contributes min(max(original_n - 32*p, 0), 32) bits, so a
// partial final word uses only its low-order bits and words wholly past
// original_n contribute nothing.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Threads stride over the
// M*K outputs, so any grid size is correct. Consecutive threads handle
// consecutive k, which keeps loads from B and stores to C contiguous within a
// row. All offsets are computed in 64-bit to avoid int overflow.
__global__ void binary_gemm_kernel(
    const uint32_t *__restrict__ A,
    const uint32_t *__restrict__ B,
    float *__restrict__ C,
    int M,
    int K,
    int L,
    int original_n) {
  const int64_t total = static_cast<int64_t>(M) * K;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < total; idx += stride) {
    const int64_t m = idx / K;
    const int64_t k = idx % K;
    const uint32_t *a_row = A + m * L;
    int64_t acc = 0;
    for (int p = 0; p < L; ++p) {
      // Valid bits in word p: 32 for full words, the remainder for the last
      // partial word, and 0 for any word entirely beyond original_n.
      const int64_t remaining =
          static_cast<int64_t>(original_n) - 32 * static_cast<int64_t>(p);
      const int bits = remaining <= 0 ? 0 : (remaining >= 32 ? 32 : static_cast<int>(remaining));
      acc += word_dot_masked(a_row[p], B[static_cast<int64_t>(p) * K + k], bits);
    }
    // idx == m * K + k, i.e. the row-major position of C[m][k].
    C[idx] = static_cast<float>(acc);
  }
}
|
|
// Host entry point for the packed binary GEMM.
//
// packed_input    : 2-D [M, L] tensor of 32-bit packed words (read via data_ptr<uint32_t>).
// packed_weight_t : 2-D [L, K] tensor of 32-bit packed words.
// original_n      : number of meaningful bit positions along the packed (L) dimension.
//                   Only the last word may be partial, so it must satisfy
//                   32 * (L - 1) < original_n.
// Returns a float32 [M, K] tensor on packed_input's device.
//
// The kernel is enqueued on the current CUDA stream of that device and is not
// synchronized here. Launch-configuration errors are reported immediately;
// in-kernel faults surface at a later synchronizing call.
//
// Raises (TORCH_CHECK) on:
//   - non-CUDA or mismatched-device inputs,
//   - non-2-D inputs or an L mismatch,
//   - invalid original_n,
//   - sizes that do not fit the kernel's int parameters.
torch::Tensor binary_gemm_forward_cuda(
    torch::Tensor packed_input,
    torch::Tensor packed_weight_t,
    int64_t original_n) {
  // A CPU tensor would hand a host pointer to the kernel.
  TORCH_CHECK(packed_input.is_cuda(), "binary_gemm: packed_input must be a CUDA tensor");
  TORCH_CHECK(packed_weight_t.is_cuda(), "binary_gemm: packed_weight_t must be a CUDA tensor");
  TORCH_CHECK(packed_input.device() == packed_weight_t.device(),
              "binary_gemm: packed_input and packed_weight_t must be on the same device");
  TORCH_CHECK(packed_input.dim() == 2,
              "binary_gemm: packed_input must be 2-D, got ", packed_input.dim(), "-D");
  TORCH_CHECK(packed_weight_t.dim() == 2,
              "binary_gemm: packed_weight_t must be 2-D, got ", packed_weight_t.dim(), "-D");

  const int64_t M = packed_input.size(0);
  const int64_t L = packed_input.size(1);
  const int64_t Lw = packed_weight_t.size(0);
  const int64_t K = packed_weight_t.size(1);
  TORCH_CHECK(L == Lw, "binary_gemm: packed dim mismatch ", L, " vs ", Lw);
  TORCH_CHECK(original_n > 0, "binary_gemm: original_n must be positive");
  // Packing contract: only the last word may be partial.
  TORCH_CHECK(original_n > 32 * (L - 1),
              "binary_gemm: original_n (", original_n, ") is too small for ", L,
              " packed words");

  // The kernel takes int sizes and forms m*L, p*K and m*K offsets; keep all
  // of them within int range. Per-dimension checks first so the products
  // below cannot overflow int64.
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  TORCH_CHECK(M <= kIntMax && L <= kIntMax && K <= kIntMax && original_n <= kIntMax,
              "binary_gemm: sizes must fit in a 32-bit int");
  TORCH_CHECK(M * K <= kIntMax && M * L <= kIntMax && L * K <= kIntMax,
              "binary_gemm: problem too large for 32-bit indexing");

  auto output = torch::empty(
      {M, K},
      torch::dtype(torch::kFloat32).device(packed_input.device()));
  // Nothing to compute; a zero-block launch would be reported as an error.
  if (M == 0 || K == 0) {
    return output;
  }

  // Make the inputs' device current so the launch matches the stream's device.
  const c10::cuda::CUDAGuard device_guard(packed_input.device());

  // The kernel indexes raw pointers as dense row-major buffers, so strided
  // views (e.g. a transposed weight) must be materialized first.
  torch::Tensor input_c = packed_input.contiguous();
  torch::Tensor weight_c = packed_weight_t.contiguous();
  const uint32_t *A = input_c.data_ptr<uint32_t>();
  const uint32_t *B = weight_c.data_ptr<uint32_t>();
  float *C = output.data_ptr<float>();

  const int64_t total = M * K;
  constexpr int threads = 256;
  const int blocks = static_cast<int>((total + threads - 1) / threads);
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(packed_input.get_device()).stream();

  binary_gemm_kernel<<<blocks, threads, 0, stream>>>(
      A, B, C, static_cast<int>(M), static_cast<int>(K), static_cast<int>(L),
      static_cast<int>(original_n));
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "binary_gemm_kernel failed: ", cudaGetErrorString(err));

  return output;
}
|
|