/*
 * Web-page header captured together with this file (kept verbatim):
 *   YYYYYYUUU's picture
 *   Backup FULL poplab work tree (source, configs, libs, scripts) excl. .pth
 *   08cde47 verified
 *   Raw
 *   History Blame Contribute Delete
 *   2.58 kB
 */
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#include <limits>
// Signed dot product of two packed 32-bit words over their low `bits` positions.
//
// Each bit b in {0,1} encodes the value 2*b - 1 (0 -> -1, 1 -> +1). Expanding
//   sum_i (2*a_i - 1)(2*b_i - 1) = 4*popc(a & b) - 2*popc(a) - 2*popc(b) + bits
// reduces the dot product to three population counts.
//
// Bits at positions >= `bits` are masked off before counting.
//   - bits <= 0 returns 0.
//   - bits >= 32 masks nothing. The trailing `+ bits` term still uses the raw
//     value, so callers should pass bits <= 32.
__device__ __forceinline__ int32_t word_dot_masked(uint32_t a, uint32_t b, int bits) {
  if (bits > 0) {
    uint32_t keep = 0xffffffffu;
    if (bits < 32) {
      keep = (1u << static_cast<unsigned int>(bits)) - 1u;
    }
    const uint32_t x = a & keep;
    const uint32_t y = b & keep;
    const int32_t both_set = static_cast<int32_t>(__popc(x & y));
    const int32_t total_set =
        static_cast<int32_t>(__popc(x)) + static_cast<int32_t>(__popc(y));
    return 4 * both_set - 2 * total_set + bits;
  }
  return 0;
}
// Binary (+-1) GEMM on bit-packed operands:
//   C[m, k] = sum over the first original_n bit positions of (2*a - 1)(2*b - 1).
//
// Memory layout (dense, row-major, 32-bit words):
//   A: [M, L]. Row m holds the packed bits of input row m.
//   B: [L, K]. Column k holds the packed bits of weight column k.
//   C: [M, K]. float32 output.
// Bit coverage: words p < L-1 contribute all 32 bits. The last word (p == L-1)
// contributes original_n - 32*(L-1) bits, clamped to 32.
//
// Launch: 1-D grid of any size, no shared memory.
//   - Each thread processes outputs idx, idx + gridDim.x*blockDim.x, ...
//     (grid-stride loop), so the grid need not cover all M*K outputs.
//   - All flat/element offsets are computed in int64_t, so M*K, M*L and L*K
//     may exceed 2^31 without overflow.
__global__ void binary_gemm_kernel(
    const uint32_t *__restrict__ A,
    const uint32_t *__restrict__ B,
    float *__restrict__ C,
    int M,
    int K,
    int L,
    int original_n) {
  const int64_t total = static_cast<int64_t>(M) * K;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

  // Valid bits in the final packed word; loop-invariant, so computed once.
  int tail_bits = original_n - 32 * (L - 1);
  if (tail_bits > 32) {
    tail_bits = 32;
  }

  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < total;
       idx += stride) {
    const int64_t m = idx / K;
    const int64_t k = idx % K;
    const uint32_t *a_row = A + m * L;  // A[m, :]
    const uint32_t *b_col = B + k;      // B[:, k], stride K
    int64_t acc = 0;
    for (int p = 0; p < L; ++p) {
      const int bits = (p == L - 1) ? tail_bits : 32;
      acc += word_dot_masked(a_row[p], b_col[static_cast<int64_t>(p) * K], bits);
    }
    C[idx] = static_cast<float>(acc);  // idx == m * K + k
  }
}
// Host entry point: binary (+-1) GEMM of bit-packed operands on the GPU.
//
// Args:
//   packed_input:    [M, L] CUDA tensor of uint32 words (dtype enforced by
//                    data_ptr<uint32_t>()).
//   packed_weight_t: [L, K] CUDA tensor of uint32 words on the same device.
//   original_n:      number of meaningful bits per packed row/column. Must satisfy
//                    32*(L-1) < original_n <= 32*L, i.e. only the last word may be
//                    partially filled.
// Returns:
//   [M, K] float32 tensor on the inputs' device holding the +-1 dot products.
// Errors:
//   Throws via TORCH_CHECK when:
//     - the inputs are not 2-D CUDA tensors on the same device;
//     - the packed dimensions differ;
//     - original_n is out of range;
//     - a size does not fit the kernel's 32-bit parameters;
//     - the kernel launch fails.
//
// Notes:
//   - Inputs are made contiguous first because the kernel indexes them as dense
//     row-major arrays. A transposed weight view would otherwise be read
//     incorrectly.
//   - The launch is asynchronous on the current CUDA stream of the inputs' device.
torch::Tensor binary_gemm_forward_cuda(
    torch::Tensor packed_input,
    torch::Tensor packed_weight_t,
    int64_t original_n) {
  TORCH_CHECK(packed_input.is_cuda() && packed_weight_t.is_cuda(),
              "binary_gemm: inputs must be CUDA tensors");
  TORCH_CHECK(packed_input.device() == packed_weight_t.device(),
              "binary_gemm: inputs must be on the same device");
  TORCH_CHECK(packed_input.dim() == 2 && packed_weight_t.dim() == 2,
              "binary_gemm: inputs must be 2-D, got ", packed_input.dim(),
              " and ", packed_weight_t.dim());

  const int64_t M = packed_input.size(0);
  const int64_t L = packed_input.size(1);
  const int64_t Lw = packed_weight_t.size(0);
  const int64_t K = packed_weight_t.size(1);
  TORCH_CHECK(L == Lw, "binary_gemm: packed dim mismatch ", L, " vs ", Lw);
  TORCH_CHECK(original_n > 0, "binary_gemm: original_n must be positive");
  // Only the final word may be partial. Otherwise every earlier word is still
  // counted as 32 bits (e.g. an all-zero padding word adds +32), which silently
  // corrupts the result.
  TORCH_CHECK(original_n > 32 * (L - 1) && original_n <= 32 * L,
              "binary_gemm: original_n=", original_n,
              " is inconsistent with ", L, " packed 32-bit words");
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  // The kernel takes M, K, L and original_n as int. L <= original_n / 32 + 1
  // follows from the check above, so bounding original_n also bounds L.
  TORCH_CHECK(M <= kIntMax && K <= kIntMax && original_n <= kIntMax,
              "binary_gemm: sizes exceed 32-bit kernel parameters");

  // Make the inputs' device current so allocations and the launch target it.
  const c10::cuda::CUDAGuard device_guard(packed_input.device());
  const torch::Tensor input = packed_input.contiguous();
  const torch::Tensor weight = packed_weight_t.contiguous();

  auto output = torch::empty(
      {M, K},
      torch::dtype(torch::kFloat32).device(packed_input.device()));
  const int64_t total = M * K;
  if (total == 0) {
    // Launching zero blocks is an invalid configuration; nothing to compute.
    return output;
  }

  const uint32_t *A = input.data_ptr<uint32_t>();
  const uint32_t *B = weight.data_ptr<uint32_t>();
  float *C = output.data_ptr<float>();

  constexpr int kThreads = 256;
  const int64_t wanted_blocks = (total + kThreads - 1) / kThreads;
  // Clamp to the gridDim.x limit (2^31 - 1).
  const int blocks =
      static_cast<int>(wanted_blocks < kIntMax ? wanted_blocks : kIntMax);
  cudaStream_t stream =
      c10::cuda::getCurrentCUDAStream(packed_input.get_device()).stream();
  binary_gemm_kernel<<<blocks, kThreads, 0, stream>>>(
      A, B, C,
      static_cast<int>(M),
      static_cast<int>(K),
      static_cast<int>(L),
      static_cast<int>(original_n));
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "binary_gemm_kernel failed: ", cudaGetErrorString(err));
  return output;
}