// flute_kernels / flute_cuda / hadamard_transform.cpp
// Contributor: galqiwi
// Initial source: FLUTE kernel scaffold (vendored CUTLASS, split TUs)
// Commit 67a5826 (verified)
// borrowed from https://github.com/pytorch-labs/applied-ai/tree/main/kernels/cuda/inference/hadamard_transform
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
using namespace at::indexing;
// Fast Hadamard transform launcher. Only declared here; its definition is not
// part of this file's visible content (presumably a separate CUDA translation
// unit — confirm). hadamard_transform below calls it only with
// at::ScalarType::Half and at::ScalarType::BFloat16.
//   a        - data pointer of the (CUDA) input tensor.
//   out      - data pointer of the output tensor; hadamard_transform passes the
//              same pointer as `a` when transforming in place.
//   numel    - total number of elements in the buffer (32-bit count).
//   had_size - Hadamard size, i.e. the length of the input's last dimension.
//   stream   - CUDA stream the work is enqueued on (the caller passes the
//              current ATen stream).
// NOTE(review): the caller pads numel to a multiple of 256 before calling;
// presumably a requirement of the kernel's tiling — verify against the definition.
template <at::ScalarType dtype>
void run_fht(void* a, void* out, uint32_t numel, uint32_t had_size, cudaStream_t stream);
// True when `x` has exactly one bit set (1, 2, 4, ..., 2^31); zero is rejected.
constexpr bool is_power_of_two(uint32_t x) {
    // Subtracting one flips the lowest set bit and everything below it, so the
    // AND is zero only when that lowest bit was the sole bit set.
    return (x != 0u) && ((x & (x - 1u)) == 0u);
}
// Applies the Hadamard transform implemented by run_fht along the last
// dimension of `in`.
//
// Requirements (enforced with TORCH_CHECK):
//   - `in` is an fp16 or bf16 CUDA tensor with at least one dimension;
//   - its last dimension is a power of two no larger than 2^15;
//   - after padding, the element count fits in run_fht's 32-bit `numel`.
//
// The input is viewed as a (-1, had_size) matrix. If the element count is not
// a multiple of 256, zero rows are appended before the launch and sliced off
// afterwards. run_fht only receives raw data pointers plus an element count,
// so the buffer handed to it is always made dense (contiguous) first.
//
// Params:
//   in      - input tensor; overwritten with the result when `inplace` is true.
//   inplace - when true, the result is written back into `in`'s storage and
//             `in` (or a view of it) is returned instead of a new tensor.
// Returns: a tensor with the same shape as `in` holding the transformed values.
at::Tensor hadamard_transform(at::Tensor& in, bool inplace) {
    auto dtype = in.scalar_type();
    TORCH_CHECK(dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16, "Only fp16 and bf16 supported currently");
    TORCH_CHECK(in.is_cuda(), "hadamard_transform: expected a CUDA tensor");
    TORCH_CHECK(in.dim() >= 1, "hadamard_transform: expected a tensor with at least one dimension");

    // Validate in 64-bit before narrowing, so an oversized last dimension cannot
    // wrap into something that looks like a valid power of two.
    const int64_t last_dim = in.size(-1);
    TORCH_CHECK(last_dim <= (int64_t{1} << 15) && is_power_of_two(static_cast<uint32_t>(last_dim)),
        "Only power of two Hadamard sizes up to 2^15 are supported, got ", last_dim);
    const int had_size = static_cast<int>(last_dim);

    // Make `in`'s device current for the allocations below, the stream lookup and the launch.
    const at::cuda::CUDAGuard device_guard(in.device());

    const auto res_shape = in.sizes();
    const int64_t numel = in.numel();
    if (numel == 0) {
        // Nothing to transform: skip the launch rather than rely on run_fht
        // accepting an empty problem.
        return inplace ? in : at::empty_like(in);
    }

    at::Tensor x = in.reshape({-1, had_size});
    if (numel % 256 != 0) {
        // Append zero rows so the element count becomes a multiple of 256.
        // had_size is a power of two: when it is < 256 it divides
        // (256 - numel % 256) exactly; when it is >= 256, numel is already a
        // multiple of 256 and this branch is never taken.
        x = at::constant_pad_nd(x, at::IntArrayRef({0, 0, 0, static_cast<int64_t>((256 - numel % 256) / had_size)}), 0);
    }
    // run_fht sees only a pointer and an element count, so the buffer must be
    // densely packed. A unit last-dim stride is not enough: a row-strided view
    // such as t[:, :H] has stride(-1) == 1 but gaps between rows.
    // contiguous() is a no-op (no copy) when x is already contiguous.
    x = x.contiguous();
    TORCH_CHECK(x.numel() < (int64_t{1} << 32),
        "hadamard_transform: ", x.numel(), " elements exceed the 32-bit element count supported by run_fht");

    at::Tensor out = inplace ? x : at::empty_like(x);
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const auto padded_numel = static_cast<uint32_t>(x.numel());
    const auto fht_size = static_cast<uint32_t>(had_size);
    if (dtype == at::ScalarType::Half) {
        run_fht<at::ScalarType::Half>(x.data_ptr(), out.data_ptr(), padded_numel, fht_size, stream);
    } else {
        run_fht<at::ScalarType::BFloat16>(x.data_ptr(), out.data_ptr(), padded_numel, fht_size, stream);
    }

    if (numel % 256 != 0) {
        // Drop the zero rows appended above.
        out = out.index({Slice(0, numel / had_size)});
    }
    // In-place but the kernel wrote into a temporary (padded or densified copy):
    // copy the result back into the caller's storage.
    if (inplace && out.data_ptr() != in.data_ptr()) {
        in.copy_(out.view(res_shape));
        return in;
    }
    return out.reshape(res_shape);
}