Instructions to use galqiwi/flute_kernels with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use galqiwi/flute_kernels with Kernels:
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("galqiwi/flute_kernels")
- Notebooks
- Google Colab
- Kaggle
| at::Tensor | |
| hadamard_transform(at::Tensor& in, | |
| bool inplace); | |
| template < | |
| typename T, | |
| typename TQ, | |
| typename T2, | |
| typename NumBits, | |
| typename GroupSize | |
| > | |
| void | |
| _qgemm_raw(int M, | |
| int N, | |
| int K, | |
| int P, | |
| const T * const __restrict__ A, | |
| const TQ* const __restrict__ Q, | |
| T * __restrict__ D, | |
| const T * const __restrict__ S, | |
| const T * const __restrict__ QM, | |
| const T2* const __restrict__ QM2, | |
| void* __restrict__ workspace, | |
| const int template_id, | |
| const int num_sms, | |
| const cudaStream_t stream); | |
| template < | |
| typename T, | |
| typename NumBits, | |
| typename GroupSize | |
| > | |
| void | |
| qgemm_raw(const at::Tensor& input, | |
| const at::Tensor& weight, | |
| at::Tensor& output, | |
| const at::Tensor& scales, | |
| const at::Tensor& table, | |
| const at::Tensor& table2, | |
| at::Tensor& workspace, | |
| const int template_id, | |
| const int num_sms, | |
| const cudaStream_t stream) | |
| { | |
| using namespace cute; | |
| using TQ = cute::uint16_t; | |
| using T2 = conditional_t<is_same_v<T, half_t>, __half2, __nv_bfloat162>; | |
| _qgemm_raw< | |
| T, | |
| TQ, | |
| T2, | |
| NumBits, | |
| GroupSize | |
| > ( | |
| output.size(0), // M | |
| output.size(1), // N | |
| input .size(1), // K | |
| weight.size(0), // P | |
| reinterpret_cast<const T *>(input .data_ptr()), | |
| reinterpret_cast<const TQ*>(weight .data_ptr()), | |
| reinterpret_cast< T *>(output .data_ptr()), | |
| reinterpret_cast<const T *>(scales .data_ptr()), | |
| reinterpret_cast<const T *>(table .data_ptr()), | |
| reinterpret_cast<const T2*>(table2 .data_ptr()), | |
| reinterpret_cast< void*>(workspace.data_ptr()), | |
| template_id, | |
| num_sms, | |
| stream); | |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); | |
| } | |
| at::Tensor | |
| qgemm_raw_simple(const at::Tensor& input, | |
| const at::Tensor& weight, | |
| const at::Tensor& scales, | |
| const at::Tensor& table, | |
| const at::Tensor& table2, | |
| at::Tensor& workspace, | |
| const cute::int64_t num_bits, | |
| const cute::int64_t group_size, | |
| const cute::int64_t template_id, | |
| const cute::int64_t num_sms) | |
| { | |
| // Set the device of this function, primarily used when | |
| // we have multiple devices in the same process. | |
| const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | |
| // Get the current CUDA stream, primarily used | |
| // to make CUDA Graphs work. | |
| const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | |
| // Squash the batch dimensions of the input tensor with its | |
| // next-to-last dimensions. | |
| const auto input_sizes = input.sizes().vec(); | |
| const auto input_2d = input.reshape({-1, input_sizes.back()}); | |
| auto output = at::empty( | |
| { | |
| input_2d.size(0), | |
| scales.size(0) | |
| }, | |
| at::TensorOptions() | |
| .dtype(input_2d.dtype()) | |
| .device(input_2d.device())); | |
| AT_DISPATCH_SWITCH( | |
| input.scalar_type(), | |
| "qgemm_raw_simple", | |
| AT_DISPATCH_CASE( | |
| at::ScalarType::Half, | |
| [&]() { | |
| RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE(cute::half_t); | |
| return; | |
| } | |
| ) | |
| AT_DISPATCH_CASE( | |
| at::ScalarType::BFloat16, | |
| [&]() { | |
| RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE(cute::bfloat16_t); | |
| return; | |
| } | |
| ) | |
| ); | |
| auto output_sizes = input_sizes; | |
| output_sizes.back() = scales.size(0); | |
| return output.reshape(output_sizes); | |
| } | |
| at::Tensor | |
| apply_hadamard(const at::Tensor& input, | |
| const cute::int64_t hadamard_size) | |
| { | |
| auto input_sizes = input.sizes(); | |
| auto flat_input = input.reshape({-1, hadamard_size}); | |
| auto had_input = hadamard_transform( | |
| flat_input, false | |
| ); | |
| return had_input.reshape(input_sizes); | |
| } | |
| at::Tensor | |
| qgemm_raw_simple_hadamard(const at::Tensor& input, | |
| const at::Tensor& weight, | |
| const at::Tensor& scales, | |
| const at::Tensor& table, | |
| const at::Tensor& table2, | |
| at::Tensor& workspace, | |
| const cute::int64_t num_bits, | |
| const cute::int64_t group_size, | |
| const cute::int64_t hadamard_size, | |
| const cute::int64_t template_id, | |
| const cute::int64_t num_sms) | |
| { | |
| auto had_input = apply_hadamard( | |
| input, | |
| hadamard_size | |
| ); | |
| return qgemm_raw_simple( | |
| had_input, | |
| weight, | |
| scales, | |
| table, | |
| table2, | |
| workspace, | |
| num_bits, | |
| group_size, | |
| template_id, | |
| num_sms | |
| ); | |
| } | |
| // Op registration is in torch-ext/torch_binding.cpp so that kernel-builder | |
| // can mangle the namespace per build hash for multi-version coexistence. | |