Instructions to use galqiwi/flute_kernels with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use galqiwi/flute_kernels with Kernels:
  # !pip install kernels
  from kernels import get_kernel
  kernel = get_kernel("galqiwi/flute_kernels")
- Notebooks
- Google Colab
- Kaggle
File size: 1,951 Bytes
Revision 67a5826
// borrowed from https://github.com/pytorch-labs/applied-ai/tree/main/kernels/cuda/inference/hadamard_transform
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
using namespace at::indexing;
// Fast Hadamard transform launcher. Its definition is not in this file.
// hadamard_transform passes raw device pointers for the input and the output
// (they alias when transforming in place), the total element count, the
// transform length, and the CUDA stream to launch on. It is used with
// dtype == Half and dtype == BFloat16.
// NOTE(review): presumably each consecutive run of `had_size` elements is
// transformed independently — confirm against the kernel source.
template <at::ScalarType dtype>
void run_fht(void* input,
             void* output,
             uint32_t numel,
             uint32_t had_size,
             cudaStream_t stream);
// Returns true when x has exactly one bit set (1, 2, 4, ..., 2^31).
// Zero is rejected explicitly.
constexpr bool is_power_of_two(uint32_t x) {
  // x & (x - 1) clears the lowest set bit; only a single-bit value becomes 0.
  return x != 0 && (x & (x - 1)) == 0;
}
// Applies the fast Hadamard transform (via run_fht) along the last dimension
// of a CUDA fp16/bf16 tensor.
//
// in      : CUDA tensor of dtype Half or BFloat16. Its last dimension must be
//           a power of two no larger than 2^15.
// inplace : when true, the result ends up in `in`'s storage. The kernel writes
//           directly into it when possible; otherwise the result is copied
//           back. The returned tensor shares storage with `in`. When false, a
//           freshly allocated tensor is returned.
// returns : a tensor with the same shape as `in` holding the transformed values.
// throws  : c10::Error (via TORCH_CHECK) on unsupported dtype, a non-CUDA
//           input, an invalid last-dimension size, or more than 2^32 - 1
//           elements after padding.
//
// run_fht receives raw pointers and a uint32_t element count. The buffer
// handed to it must therefore be fully contiguous, and its size must fit in
// 32 bits. The row count is padded so the total element count is a multiple
// of 256, and the padding rows are sliced off before returning.
at::Tensor hadamard_transform(at::Tensor& in, bool inplace) {
  const auto dtype = in.scalar_type();
  TORCH_CHECK(dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16, "Only fp16 and bf16 supported currently");
  TORCH_CHECK(in.is_cuda(), "hadamard_transform expects a CUDA tensor");

  // Validate the size as int64_t before narrowing. Truncating first could turn
  // an oversized last dimension into a value that passes the power-of-two test.
  const int64_t had_size = in.size(-1);
  TORCH_CHECK(had_size > 0 && had_size <= (int64_t{1} << 15) && is_power_of_two(static_cast<uint32_t>(had_size)),
              "Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size);

  if (in.numel() == 0) {
    // Nothing to transform; skip the kernel launch entirely.
    return inplace ? in : at::empty_like(in);
  }

  const auto res_shape = in.sizes();
  const int64_t numel = in.numel();
  at::Tensor x = in.reshape({-1, had_size});

  const int64_t remainder = numel % 256;
  if (remainder != 0) {
    // remainder != 0 implies had_size < 256. Both numel and 256 are then
    // multiples of had_size, so this pads by a whole number of rows.
    const int64_t pad_rows = (256 - remainder) / had_size;
    x = at::constant_pad_nd(x, at::IntArrayRef({0, 0, 0, pad_rows}), 0);
  }

  // The kernel walks raw memory as back-to-back rows of had_size elements, so
  // the whole buffer must be contiguous. Checking only stride(-1) misses views
  // whose rows are not packed, e.g. a slice along the last dimension or
  // `t[::2]`. Such views would be read and written at the wrong addresses.
  if (!x.is_contiguous()) {
    x = x.contiguous();
  }

  TORCH_CHECK(x.numel() < (int64_t{1} << 32),
              "hadamard_transform supports at most 2^32 - 1 elements, got ", x.numel());

  at::Tensor out = inplace ? x : at::empty_like(x);

  at::cuda::CUDAGuard device_guard{x.device()};
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto kernel_numel = static_cast<uint32_t>(x.numel());
  const auto kernel_had_size = static_cast<uint32_t>(had_size);
  if (dtype == at::ScalarType::Half) {
    run_fht<at::ScalarType::Half>(x.data_ptr(), out.data_ptr(), kernel_numel, kernel_had_size, stream);
  } else {
    run_fht<at::ScalarType::BFloat16>(x.data_ptr(), out.data_ptr(), kernel_numel, kernel_had_size, stream);
  }

  if (remainder != 0) {
    // Drop the padding rows added above.
    out = out.index({Slice(0, numel / had_size)});
  }

  if (inplace && out.data_ptr() != in.data_ptr()) {
    // The kernel ran on a copy (padding or contiguity fix-up), so write the
    // result back into the caller's tensor.
    in.copy_(out.view(res_shape));
    return in;
  }
  return out.reshape(res_shape);
}
|