File size: 1,951 Bytes
67a5826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// borrowed from https://github.com/pytorch-labs/applied-ai/tree/main/kernels/cuda/inference/hadamard_transform

#include <cstdint>
#include <limits>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

using namespace at::indexing;

// Dtype-specific launcher for the Hadamard transform kernel. It is only declared
// here; its definition is not in this file (presumably a separate CUDA source).
// hadamard_transform below calls it for at::ScalarType::Half and
// at::ScalarType::BFloat16.
//   a        - pointer to the input data (the caller passes a CUDA tensor's data_ptr())
//   out      - pointer to the output data; the caller passes the same buffer as `a`
//              when transforming in place
//   numel    - total element count of the (padded) input buffer
//   had_size - Hadamard size, i.e. the length of the transformed last dimension
//   stream   - CUDA stream taken from getCurrentCUDAStream(); presumably the
//              kernel is launched on it — confirm against the definition
// NOTE(review): no strides are passed, so the buffer is assumed to be dense.
template <at::ScalarType dtype>
void run_fht(void* a, void* out, uint32_t numel, uint32_t had_size, cudaStream_t stream);

// Returns true exactly when `x` has a single set bit (1, 2, 4, ..., 2^31).
// Clearing the lowest set bit via x & (x - 1) leaves zero only for powers of two
// -- and for zero itself, which is therefore rejected separately.
constexpr bool is_power_of_two(uint32_t x) {
    const bool has_any_bit = x != 0u;
    const bool has_one_bit_at_most = (x & (x - 1u)) == 0u;
    return has_any_bit && has_one_bit_at_most;
}

// Applies a Hadamard transform of size in.size(-1) along the last dimension of
// `in` by delegating to the dtype-specific run_fht launcher. Whether the result
// is scaled/normalized is decided inside run_fht and is not visible here.
//
// The input is viewed as a (rows, had_size) matrix. It is padded with zero rows
// when its element count is not a multiple of 256 and made fully contiguous,
// because run_fht only receives a raw pointer and an element count (no strides).
//
//   in       - CUDA tensor of dtype fp16 or bf16 whose last dimension is a power
//              of two no larger than 2^15.
//   inplace  - when true, the result is written into `in`'s memory: directly if
//              the working buffer aliases `in`, otherwise via a final copy_.
//   returns  - a tensor with the same shape as `in` holding the transform; with
//              `inplace` it shares `in`'s storage.
//   throws   - c10::Error (TORCH_CHECK) for an unsupported dtype, device or
//              last-dimension size, or if the padded element count exceeds the
//              32-bit count that run_fht accepts.
at::Tensor hadamard_transform(at::Tensor& in, bool inplace) {
    const auto dtype = in.scalar_type();
    TORCH_CHECK(dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16, "Only fp16 and bf16 supported currently");
    TORCH_CHECK(in.is_cuda(), "hadamard_transform: input must be a CUDA tensor");

    // Keep the size in 64 bits until it has been range-checked; narrowing it to
    // int first could wrap a huge dimension onto a value that looks valid.
    const int64_t had_size = in.size(-1);
    TORCH_CHECK(had_size > 0 && had_size <= (int64_t{1} << 15) &&
                    is_power_of_two(static_cast<uint32_t>(had_size)),
        "Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size);

    const auto res_shape = in.sizes();
    const int64_t numel = in.numel();

    // Nothing to transform: skip the kernel launch entirely (how run_fht would
    // handle numel == 0 is not visible from this file).
    if (numel == 0) {
        return inplace ? in : at::empty_like(in);
    }

    // Make in's device current for every allocation and launch below.
    at::cuda::CUDAGuard device_guard{in.device()};

    at::Tensor x = in.reshape({-1, had_size});

    // Pad with zero rows so the element count becomes a multiple of 256
    // (presumably a granularity requirement of run_fht -- confirm). Since
    // had_size is a power of two, a remainder can only occur when had_size < 256,
    // and then (256 - numel % 256) is always a whole number of rows.
    if (numel % 256 != 0) {
        const int64_t pad_rows = (256 - numel % 256) / had_size;
        x = at::constant_pad_nd(x, at::IntArrayRef({0, 0, 0, pad_rows}), 0);
    }

    // run_fht gets no strides, so the buffer must be fully dense. A unit stride
    // in the last dim is not enough: reshape() may return a view with a padded
    // row stride (e.g. a column slice) or a zero stride (an expanded tensor).
    // contiguous() returns `x` itself when it is already dense.
    x = x.contiguous();

    TORCH_CHECK(x.numel() <= static_cast<int64_t>(std::numeric_limits<uint32_t>::max()),
        "hadamard_transform: too many elements for the 32-bit kernel count, got ", x.numel());

    at::Tensor out = inplace ? x : at::empty_like(x);
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const auto kernel_numel = static_cast<uint32_t>(x.numel());
    const auto kernel_had_size = static_cast<uint32_t>(had_size);

    if (dtype == at::ScalarType::Half) {
        run_fht<at::ScalarType::Half>(x.data_ptr(), out.data_ptr(), kernel_numel, kernel_had_size, stream);
    } else {
        run_fht<at::ScalarType::BFloat16>(x.data_ptr(), out.data_ptr(), kernel_numel, kernel_had_size, stream);
    }

    // Drop the padding rows that were appended above.
    if (numel % 256 != 0) {
        out = out.narrow(0, 0, numel / had_size);
    }

    // The transform ran in a temporary (padded or contiguous copy), so write it
    // back into the caller's tensor.
    if (inplace && out.data_ptr() != in.data_ptr()) {
        in.copy_(out.view(res_shape));
        return in;
    }
    return out.reshape(res_shape);
}