#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <limits>

// Sign encoding: a bit b in {0,1} stands for the value 2*b - 1 (0 -> -1,
// 1 -> +1). Over `bits` positions the signed dot product is
//   sum (2*a_i-1)(2*b_i-1) = 4*popc(a&b) - 2*popc(a) - 2*popc(b) + bits
//                          = bits - 2*popc(a ^ b)
// i.e. (#matching positions) - (#differing positions).

// Signed dot product of the low `bits` bits of `a` and `b`.
// `bits` is clamped to [0, 32]; bit positions >= `bits` are treated as
// padding and ignored. Returns a value in [-bits, bits].
__device__ __forceinline__ int32_t word_dot_masked(uint32_t a, uint32_t b,
                                                   int bits) {
  if (bits <= 0) {
    return 0;
  }
  if (bits > 32) {
    bits = 32;
  }
  const uint32_t mask =
      (bits == 32) ? 0xffffffffu : ((1u << static_cast<uint32_t>(bits)) - 1u);
  // One popcount instead of three: every differing position contributes -1
  // instead of +1, a net change of -2 relative to `bits`.
  return bits - 2 * static_cast<int32_t>(__popc((a ^ b) & mask));
}

// Computes C[m, k] = sum_p word_dot_masked(A[m, p], B[p, k], bits_p) as float.
//
// Layout (dense, row-major):
//   A: (M, L) packed words of the left operand.
//   B: (L, K) packed words of the right operand (transposed weight).
//   C: (M, K) float32 output.
// Word p covers logical bit positions [32*p, 32*p + 32); only the first
// `original_n` logical positions are real, the rest are padding.
//
// Launch: 1-D grid of 1-D blocks. Each thread walks output elements with a
// grid-stride loop, so any grid size >= 1 is correct. No shared memory.
__global__ void binary_gemm_kernel(
    const uint32_t *__restrict__ A, const uint32_t *__restrict__ B,
    float *__restrict__ C, int M, int K, int L, int original_n) {
  // 64-bit indexing: M * K (and row offsets) may exceed INT_MAX.
  const int64_t total = static_cast<int64_t>(M) * K;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x +
                     threadIdx.x;
       idx < total; idx += stride) {
    const int64_t m = idx / K;
    const int64_t k = idx - m * K;
    const uint32_t *a_row = A + m * L;

    int64_t acc = 0;
    for (int p = 0; p < L; ++p) {
      // Valid bits in word p: 32 for full words, the remainder for the word
      // containing position original_n - 1, and 0 for words entirely past it.
      const int64_t remaining =
          static_cast<int64_t>(original_n) - 32 * static_cast<int64_t>(p);
      const int bits = remaining <= 0
                           ? 0
                           : (remaining >= 32 ? 32 : static_cast<int>(remaining));
      acc += word_dot_masked(a_row[p], B[static_cast<int64_t>(p) * K + k],
                             bits);
    }
    C[idx] = static_cast<float>(acc);
  }
}

// Host entry point: binary (sign) GEMM of bit-packed operands.
//
// packed_input:    (M, L) contiguous CUDA tensor of packed 32-bit words,
//                  read via data_ptr<uint32_t>().
// packed_weight_t: (L, K) contiguous CUDA tensor of packed 32-bit words on the
//                  same device, read via data_ptr<uint32_t>().
// original_n:      number of real (unpadded) bit positions per row/column;
//                  must satisfy 0 < original_n <= 32 * L.
// Returns an (M, K) float32 tensor on the inputs' device holding the signed
// dot products. The kernel is enqueued on the current CUDA stream of that
// device; errors from the launch itself are raised via TORCH_CHECK.
torch::Tensor binary_gemm_forward_cuda(
    torch::Tensor packed_input, torch::Tensor packed_weight_t,
    int64_t original_n) {
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();

  TORCH_CHECK(packed_input.is_cuda() && packed_weight_t.is_cuda(),
              "binary_gemm: inputs must be CUDA tensors");
  TORCH_CHECK(packed_input.device() == packed_weight_t.device(),
              "binary_gemm: inputs must be on the same device");
  TORCH_CHECK(packed_input.dim() == 2 && packed_weight_t.dim() == 2,
              "binary_gemm: inputs must be 2-D, got ", packed_input.dim(),
              "-D and ", packed_weight_t.dim(), "-D");
  // The kernel indexes raw pointers as dense row-major matrices.
  TORCH_CHECK(packed_input.is_contiguous() && packed_weight_t.is_contiguous(),
              "binary_gemm: inputs must be contiguous");

  const int64_t M64 = packed_input.size(0);
  const int64_t L64 = packed_input.size(1);
  const int64_t Lw64 = packed_weight_t.size(0);
  const int64_t K64 = packed_weight_t.size(1);
  TORCH_CHECK(L64 == Lw64, "binary_gemm: packed dim mismatch ", L64, " vs ",
              Lw64);
  TORCH_CHECK(M64 <= kIntMax && L64 <= kIntMax && K64 <= kIntMax,
              "binary_gemm: dimensions must fit in int32");
  TORCH_CHECK(original_n > 0, "binary_gemm: original_n must be positive");
  TORCH_CHECK(original_n <= 32 * L64, "binary_gemm: original_n (",
              original_n, ") exceeds packed capacity 32 * ", L64);
  TORCH_CHECK(original_n <= kIntMax,
              "binary_gemm: original_n must fit in int32");

  const int M = static_cast<int>(M64);
  const int L = static_cast<int>(L64);
  const int K = static_cast<int>(K64);

  // Make the inputs' device current for the allocation and the launch, even
  // if the caller has a different device selected.
  const c10::cuda::CUDAGuard device_guard(packed_input.device());

  auto output = torch::empty(
      {M64, K64}, torch::dtype(torch::kFloat32).device(packed_input.device()));

  const int64_t total = M64 * K64;
  if (total == 0) {
    // A zero-block launch is an invalid configuration; nothing to compute.
    return output;
  }

  const uint32_t *A = packed_input.data_ptr<uint32_t>();
  const uint32_t *B = packed_weight_t.data_ptr<uint32_t>();
  float *C = output.data_ptr<float>();

  constexpr int threads = 256;
  // The kernel uses a grid-stride loop, so capping the grid at the gridDim.x
  // limit still covers every output element.
  const int blocks = static_cast<int>(
      std::min<int64_t>((total + threads - 1) / threads, kIntMax));

  const int dev = packed_input.get_device();
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(dev).stream();
  binary_gemm_kernel<<<blocks, threads, 0, stream>>>(
      A, B, C, M, K, L, static_cast<int>(original_n));
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "binary_gemm_kernel failed: ",
              cudaGetErrorString(err));
  return output;
}