| |
| #include <torch/extension.h> |
|
|
| torch::Tensor binary_gemm_forward_cuda( |
| torch::Tensor packed_input, |
| torch::Tensor packed_weight_t, |
| int64_t original_n); |
|
|
| torch::Tensor binary_gemm_forward( |
| torch::Tensor packed_input, |
| torch::Tensor packed_weight_t, |
| int64_t original_n) { |
| TORCH_CHECK(packed_input.is_cuda(), "binary_gemm: packed_input must be CUDA"); |
| TORCH_CHECK(packed_weight_t.is_cuda(), "binary_gemm: packed_weight_t must be CUDA"); |
| TORCH_CHECK(packed_input.scalar_type() == torch::kUInt32, |
| "binary_gemm: packed_input must be uint32"); |
| TORCH_CHECK(packed_weight_t.scalar_type() == torch::kUInt32, |
| "binary_gemm: packed_weight_t must be uint32"); |
| packed_input = packed_input.contiguous(); |
| packed_weight_t = packed_weight_t.contiguous(); |
| return binary_gemm_forward_cuda(packed_input, packed_weight_t, original_n); |
| } |
|
|
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| m.def("forward", &binary_gemm_forward, "Binary GEMM forward (packed uint32)"); |
| } |
|
|