Instructions to use galqiwi/flute_kernels with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use galqiwi/flute_kernels with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("galqiwi/flute_kernels")
- Notebooks
- Google Colab
- Kaggle
File size: 8,242 Bytes
// 67a5826 (identifier carried over from the page header)
#include <cuda_runtime.h>
#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cute/numeric/integral_constant.hpp"
// Forward declaration only; the definition is not part of this file view.
// In this file it is called by `apply_hadamard` with a 2-D reshaped tensor
// and `inplace = false`.
//
// in:      tensor to transform, taken by non-const reference.
// inplace: flag forwarded to the implementation. Presumably `false` means
//          the result is returned in a separate tensor and `in` is left
//          untouched -- TODO confirm against the definition.
// Returns an at::Tensor; `apply_hadamard` reshapes it back to the original
// input sizes, so it is expected to have the same number of elements as `in`.
at::Tensor
hadamard_transform(at::Tensor& in,
bool inplace);
// Forward declaration of the typed GEMM launcher; the definition is not part
// of this file view. The notes below describe how `qgemm_raw` in this file
// fills each argument, not what the implementation does with them.
//
// Template parameters:
//   T         element type of A, D, S and QM (cute::half_t or cute::bfloat16_t
//             at the call sites in this file).
//   TQ        element type of Q (cute::uint16_t at the call site).
//   T2        element type of QM2 (__half2 or __nv_bfloat162 at the call site).
//   NumBits   cute::Int<2|3|4> at the call sites.
//   GroupSize cute::Int<64|128|256> at the call sites.
//
// Arguments as supplied by `qgemm_raw`:
//   M, N      output.size(0), output.size(1)
//   K         input.size(1)
//   P         weight.size(0)
//   A         input data pointer
//   Q         weight data pointer
//   D         output data pointer (the only non-const data pointer besides
//             the workspace)
//   S         scales data pointer
//   QM, QM2   table and table2 data pointers
//   workspace untyped scratch pointer; size requirements are not visible
//             here -- TODO confirm against the definition
//   template_id, num_sms
//             forwarded unchanged from the caller; presumably select a kernel
//             configuration and the SM count to target -- TODO confirm
//   stream    CUDA stream passed by the caller; presumably the stream the
//             work is enqueued on -- TODO confirm
template <
typename T,
typename TQ,
typename T2,
typename NumBits,
typename GroupSize
>
void
_qgemm_raw(int M,
int N,
int K,
int P,
const T * const __restrict__ A,
const TQ* const __restrict__ Q,
T * __restrict__ D,
const T * const __restrict__ S,
const T * const __restrict__ QM,
const T2* const __restrict__ QM2,
void* __restrict__ workspace,
const int template_id,
const int num_sms,
const cudaStream_t stream);
// Tensor-level adapter around `_qgemm_raw`.
//
// Pulls the problem sizes and raw data pointers out of the given tensors,
// forwards them to `_qgemm_raw<T, uint16_t, T2, NumBits, GroupSize>` and then
// runs C10_CUDA_KERNEL_LAUNCH_CHECK() so a failed launch is reported here.
//
// T is the element type used to reinterpret input/output/scales/table;
// T2 is derived from it (__half2 when T is half_t, otherwise __nv_bfloat162)
// and is used to reinterpret table2. No dtype, shape or contiguity checks are
// performed in this function: the data pointers are reinterpreted as-is.
template <
    typename T,
    typename NumBits,
    typename GroupSize
>
void
qgemm_raw(const at::Tensor& input,
          const at::Tensor& weight,
          at::Tensor& output,
          const at::Tensor& scales,
          const at::Tensor& table,
          const at::Tensor& table2,
          at::Tensor& workspace,
          const int template_id,
          const int num_sms,
          const cudaStream_t stream)
{
    using namespace cute;
    using TQ = cute::uint16_t;
    using T2 = conditional_t<is_same_v<T, half_t>, __half2, __nv_bfloat162>;

    // Problem sizes. `size()` yields a 64-bit value while the launcher takes
    // `int`, so the narrowing that used to happen implicitly is spelled out.
    const int dim_m = static_cast<int>(output.size(0));
    const int dim_n = static_cast<int>(output.size(1));
    const int dim_k = static_cast<int>(input.size(1));
    const int dim_p = static_cast<int>(weight.size(0));

    // Raw data pointers, reinterpreted to the element types of the launcher.
    const T*  const input_ptr  = reinterpret_cast<const T *>(input.data_ptr());
    const TQ* const weight_ptr = reinterpret_cast<const TQ*>(weight.data_ptr());
    T*        const output_ptr = reinterpret_cast<T*>(output.data_ptr());
    const T*  const scales_ptr = reinterpret_cast<const T *>(scales.data_ptr());
    const T*  const table_ptr  = reinterpret_cast<const T *>(table.data_ptr());
    const T2* const table2_ptr = reinterpret_cast<const T2*>(table2.data_ptr());
    void*     const scratch    = reinterpret_cast<void*>(workspace.data_ptr());

    _qgemm_raw<T, TQ, T2, NumBits, GroupSize>(
        dim_m,
        dim_n,
        dim_k,
        dim_p,
        input_ptr,
        weight_ptr,
        output_ptr,
        scales_ptr,
        table_ptr,
        table2_ptr,
        scratch,
        template_id,
        num_sms,
        stream);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Runs the quantized GEMM for `input` and returns a newly allocated result.
//
// All leading dimensions of `input` are folded into a single row dimension,
// the typed `qgemm_raw` instantiation matching (dtype, num_bits, group_size)
// is invoked, and the result is reshaped so that it has the same leading
// dimensions as `input` with the last dimension replaced by `scales.size(0)`.
//
// Parameters:
//   input       at least 1-D tensor; its dtype must be Half or BFloat16.
//   weight, scales, table, table2, workspace
//               forwarded unchanged to `qgemm_raw`, which reads their raw
//               data pointers. `scales.size(0)` is used as the number of
//               output columns.
//   num_bits    must be 2, 3 or 4.
//   group_size  must be 64, 128 or 256.
//   template_id, num_sms
//               forwarded to `qgemm_raw` (narrowed to `int` at the call).
//
// Raises c10::Error when `input` is 0-dimensional, when `num_bits` or
// `group_size` is unsupported, or when the dtype is not handled by the
// dispatch below.
at::Tensor
qgemm_raw_simple(const at::Tensor& input,
                 const at::Tensor& weight,
                 const at::Tensor& scales,
                 const at::Tensor& table,
                 const at::Tensor& table2,
                 at::Tensor& workspace,
                 const cute::int64_t num_bits,
                 const cute::int64_t group_size,
                 const cute::int64_t template_id,
                 const cute::int64_t num_sms)
{
    // `input_sizes.back()` below is undefined on an empty size vector, so
    // reject 0-dimensional inputs with a proper error instead.
    TORCH_CHECK(
        input.dim() >= 1,
        "`input` must have at least one dimension");

    // Set the device of this function, primarily used when
    // we have multiple devices in the same process.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    // Get the current CUDA stream, primarily used
    // to make CUDA Graphs work.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Squash the batch dimensions of the input tensor with its
    // next-to-last dimensions.
    //
    // `reshape` may return a strided view (e.g. for a slice along the last
    // dimension), but `qgemm_raw` hands the raw data pointer to the kernel
    // without any stride information. `contiguous()` guarantees a dense
    // row-major buffer and returns the tensor itself when it already is one.
    const auto input_sizes = input.sizes().vec();
    const auto input_2d = input.reshape({-1, input_sizes.back()}).contiguous();

    auto output = at::empty(
        {
            input_2d.size(0),
            scales.size(0)
        },
        at::TensorOptions()
            .dtype(input_2d.dtype())
            .device(input_2d.device()));

// Instantiates and calls `qgemm_raw` for one concrete
// (element type, bit width, group size) combination.
#define RUN_QGEMM_RAW(T, NUM_BITS, GROUP_SIZE)                      \
    do {                                                            \
        qgemm_raw<                                                  \
            T,                                                      \
            cute::Int<NUM_BITS>,                                    \
            cute::Int<GROUP_SIZE>                                   \
        > (                                                         \
            input_2d,                                               \
            weight,                                                 \
            output,                                                 \
            scales,                                                 \
            table,                                                  \
            table2,                                                 \
            workspace,                                              \
            template_id,                                            \
            num_sms,                                                \
            stream);                                                \
    } while (false)

// Maps the runtime `group_size` onto a compile-time constant.
#define RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, NUM_BITS)                \
    do {                                                            \
        switch (group_size)                                         \
        {                                                           \
        case 64:                                                    \
            RUN_QGEMM_RAW(T, NUM_BITS, 64);                         \
            break;                                                  \
        case 128:                                                   \
            RUN_QGEMM_RAW(T, NUM_BITS, 128);                        \
            break;                                                  \
        case 256:                                                   \
            RUN_QGEMM_RAW(T, NUM_BITS, 256);                        \
            break;                                                  \
        default:                                                    \
            TORCH_CHECK(false, "Unsupported `group_size`");         \
        }                                                           \
    } while (false)

// Maps the runtime `num_bits` onto a compile-time constant, then
// continues with the group-size switch.
#define RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE(T)             \
    do {                                                            \
        switch (num_bits)                                           \
        {                                                           \
        case 2:                                                     \
            RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, 2);                  \
            break;                                                  \
        case 3:                                                     \
            RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, 3);                  \
            break;                                                  \
        case 4:                                                     \
            RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, 4);                  \
            break;                                                  \
        default:                                                    \
            TORCH_CHECK(false, "Unsupported `num_bits`");           \
        }                                                           \
    } while (false)

    // Only Half and BFloat16 inputs are handled; any other dtype falls
    // through to the dispatch macro's own error.
    AT_DISPATCH_SWITCH(
        input.scalar_type(),
        "qgemm_raw_simple",
        AT_DISPATCH_CASE(
            at::ScalarType::Half,
            [&]() {
                RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE(cute::half_t);
                return;
            }
        )
        AT_DISPATCH_CASE(
            at::ScalarType::BFloat16,
            [&]() {
                RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE(cute::bfloat16_t);
                return;
            }
        )
    );

// The helper macros are local to this function; do not let them leak into
// the rest of the translation unit.
#undef RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE
#undef RUN_QGEMM_RAW_SWITCH_GROUP_SIZE
#undef RUN_QGEMM_RAW

    // Restore the original leading dimensions; only the last one changes.
    auto output_sizes = input_sizes;
    output_sizes.back() = scales.size(0);
    return output.reshape(output_sizes);
}
// Applies `hadamard_transform` to `input` in rows of `hadamard_size`
// elements and returns the result with the same sizes as `input`.
//
// The input is reshaped to (-1, hadamard_size) before the call, so the total
// number of elements must be compatible with that shape; otherwise `reshape`
// raises.
at::Tensor
apply_hadamard(const at::Tensor& input,
               const cute::int64_t hadamard_size)
{
    // `hadamard_transform` takes a non-const reference, hence the mutable
    // local for the 2-D reshaped tensor.
    at::Tensor rows = input.reshape({-1, hadamard_size});

    // Call with `inplace = false`.
    const at::Tensor transformed = hadamard_transform(rows, false);

    // Give the result the sizes of the original input.
    return transformed.reshape(input.sizes());
}
// Convenience wrapper: passes `input` through `apply_hadamard` (with the
// given `hadamard_size`) and feeds the result to `qgemm_raw_simple`.
// All remaining arguments are forwarded to `qgemm_raw_simple` unchanged, and
// its return value is returned as-is.
at::Tensor
qgemm_raw_simple_hadamard(const at::Tensor& input,
                          const at::Tensor& weight,
                          const at::Tensor& scales,
                          const at::Tensor& table,
                          const at::Tensor& table2,
                          at::Tensor& workspace,
                          const cute::int64_t num_bits,
                          const cute::int64_t group_size,
                          const cute::int64_t hadamard_size,
                          const cute::int64_t template_id,
                          const cute::int64_t num_sms)
{
    // Step 1: transform the activations.
    const at::Tensor transformed_input = apply_hadamard(input, hadamard_size);

    // Step 2: run the quantized GEMM on the transformed activations.
    return qgemm_raw_simple(transformed_input,
                            weight,
                            scales,
                            table,
                            table2,
                            workspace,
                            num_bits,
                            group_size,
                            template_id,
                            num_sms);
}
// Op registration is in torch-ext/torch_binding.cpp so that kernel-builder
// can mangle the namespace per build hash for multi-version coexistence.
|