#include #include "registration.h" #include "torch_binding.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "qgemm_raw_simple(Tensor input, Tensor weight, Tensor scales, " "Tensor table, Tensor table2, Tensor(a!) workspace, " "int num_bits, int group_size, int template_id, int num_sms) -> Tensor"); ops.def( "qgemm_raw_simple_hadamard(Tensor input, Tensor weight, Tensor scales, " "Tensor table, Tensor table2, Tensor(a!) workspace, " "int num_bits, int group_size, int hadamard_size, " "int template_id, int num_sms) -> Tensor"); ops.def("hadamard_transform(Tensor(a!) input, bool inplace) -> Tensor"); #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) ops.impl("qgemm_raw_simple", c10::kCUDA, &qgemm_raw_simple); ops.impl("qgemm_raw_simple_hadamard", c10::kCUDA, &qgemm_raw_simple_hadamard); ops.impl("hadamard_transform", c10::kCUDA, &hadamard_transform); #endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)