// PyTorch binding for packed binary GEMM (±1 dot product via uint32 packing). #include torch::Tensor binary_gemm_forward_cuda( torch::Tensor packed_input, torch::Tensor packed_weight_t, int64_t original_n); torch::Tensor binary_gemm_forward( torch::Tensor packed_input, torch::Tensor packed_weight_t, int64_t original_n) { TORCH_CHECK(packed_input.is_cuda(), "binary_gemm: packed_input must be CUDA"); TORCH_CHECK(packed_weight_t.is_cuda(), "binary_gemm: packed_weight_t must be CUDA"); TORCH_CHECK(packed_input.scalar_type() == torch::kUInt32, "binary_gemm: packed_input must be uint32"); TORCH_CHECK(packed_weight_t.scalar_type() == torch::kUInt32, "binary_gemm: packed_weight_t must be uint32"); packed_input = packed_input.contiguous(); packed_weight_t = packed_weight_t.contiguous(); return binary_gemm_forward_cuda(packed_input, packed_weight_t, original_n); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &binary_gemm_forward, "Binary GEMM forward (packed uint32)"); }