// flute_kernels/torch-ext/torch_binding.cpp
// Commit 67a5826 (verified), author galqiwi:
//   "Initial source: FLUTE kernel scaffold (vendored CUTLASS, split TUs)"
#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
// Operator schemas and kernel bindings for this extension's op namespace.
//
// Every schema is declared (m.def) unconditionally. Kernels are bound
// (m.impl) only when CUDA_KERNEL or ROCM_KERNEL is defined at build time.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  // Presumably the quantized GEMM entry point (name "qgemm"; confirm).
  // `workspace` is declared mutable via `Tensor(a!)`. The other tensor
  // arguments carry no alias annotation, i.e. are declared non-mutated.
  m.def(
      "qgemm_raw_simple(Tensor input, Tensor weight, Tensor scales, "
      "Tensor table, Tensor table2, Tensor(a!) workspace, "
      "int num_bits, int group_size, int template_id, int num_sms) -> Tensor");

  // Same argument list as qgemm_raw_simple, with an extra `hadamard_size`
  // integer inserted between `group_size` and `template_id`.
  m.def(
      "qgemm_raw_simple_hadamard(Tensor input, Tensor weight, Tensor scales, "
      "Tensor table, Tensor table2, Tensor(a!) workspace, "
      "int num_bits, int group_size, int hadamard_size, "
      "int template_id, int num_sms) -> Tensor");

  // `input` is declared mutable regardless of the `inplace` flag.
  // NOTE(review): the result is declared as a plain `Tensor`. If the kernel
  // returns `input` itself when inplace=true, the return should be annotated
  // `Tensor(a!)` so alias analysis / torch.compile see the alias. Confirm
  // against hadamard_transform's implementation.
  m.def("hadamard_transform(Tensor(a!) input, bool inplace) -> Tensor");

#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
  // CUDA and ROCm builds both register these kernels under c10::kCUDA.
  m.impl("qgemm_raw_simple", c10::kCUDA, &qgemm_raw_simple);
  m.impl("qgemm_raw_simple_hadamard", c10::kCUDA, &qgemm_raw_simple_hadamard);
  m.impl("hadamard_transform", c10::kCUDA, &hadamard_transform);
#endif
}
// Invokes the project's REGISTER_EXTENSION macro (presumably defined in
// registration.h, included above) for this extension's name.
// NOTE(review): likely emits the Python module init hook so the built shared
// library can be imported; confirm in registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)