YYYYYYUUU's picture
Add core reproduction code (binarization layers, PTv3, superpoint ops, min-repro pack)
7b95dc2 verified
Raw
History Blame Contribute Delete
1.08 kB
// PyTorch binding for packed binary GEMM (±1 dot product via uint32 packing).
#include <torch/extension.h>
// Forward declaration of the CUDA-side entry point. Its definition is not in
// this translation unit (presumably a companion .cu file — confirm in the
// build script). The wrapper below only calls it after checking that both
// tensors are CUDA uint32 tensors and making them contiguous.
//
// packed_input / packed_weight_t: packed uint32 operands (layout and shape are
//   defined by the CUDA implementation; not visible from here).
// original_n: passed through unchanged by the wrapper; presumably the
//   pre-packing element count along the packed dimension — TODO confirm.
torch::Tensor binary_gemm_forward_cuda(
torch::Tensor packed_input,
torch::Tensor packed_weight_t,
int64_t original_n);
/// Host-side entry point for the packed binary GEMM forward pass.
///
/// Validates the operands, makes them contiguous, and forwards them to
/// binary_gemm_forward_cuda.
///
/// @param packed_input     CUDA tensor of dtype uint32 holding the packed input.
/// @param packed_weight_t  CUDA tensor of dtype uint32 holding the packed
///                         weight; must live on the same device as packed_input.
/// @param original_n       Forwarded unchanged to binary_gemm_forward_cuda
///                         (presumably the unpacked length — confirm against
///                         the CUDA implementation).
/// @return Whatever tensor binary_gemm_forward_cuda produces.
/// @throws c10::Error (raised by TORCH_CHECK) if either tensor is not on CUDA,
///         is not uint32, or the two tensors are on different devices.
torch::Tensor binary_gemm_forward(
    torch::Tensor packed_input,
    torch::Tensor packed_weight_t,
    int64_t original_n) {
  TORCH_CHECK(packed_input.is_cuda(), "binary_gemm: packed_input must be CUDA");
  TORCH_CHECK(packed_weight_t.is_cuda(), "binary_gemm: packed_weight_t must be CUDA");
  // Both operands being CUDA is not enough: tensors on two different GPUs
  // would otherwise reach the kernel launcher unchecked.
  TORCH_CHECK(packed_input.device() == packed_weight_t.device(),
              "binary_gemm: packed_input and packed_weight_t must be on the same device (got ",
              packed_input.device(), " and ", packed_weight_t.device(), ")");
  TORCH_CHECK(packed_input.scalar_type() == torch::kUInt32,
              "binary_gemm: packed_input must be uint32");
  TORCH_CHECK(packed_weight_t.scalar_type() == torch::kUInt32,
              "binary_gemm: packed_weight_t must be uint32");

  // The parameters are by-value tensor handles, so rebinding them to
  // contiguous versions does not affect the caller's tensors.
  packed_input = packed_input.contiguous();
  packed_weight_t = packed_weight_t.contiguous();

  // NOTE(review): no device guard is set here; confirm the CUDA side selects
  // the tensors' device before launching.
  return binary_gemm_forward_cuda(packed_input, packed_weight_t, original_n);
}
// Python module definition: exposes the validated host wrapper as `forward`.
// The module name comes from TORCH_EXTENSION_NAME, which is not defined in
// this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def(
      "forward",
      &binary_gemm_forward,
      "Binary GEMM forward (packed uint32)");
}