Kernels
cuda
hadamard
galqiwi's picture
Re-initial source (hadamard kebab name, BSD-3-Clause)
e4a8c54 verified
"""Minimal usage example for the published kernel."""
import math
import torch
from kernels import get_kernel
# Obtain the kernel module through `get_kernel`, requesting version 1.
hadamard = get_kernel("galqiwi/hadamard_transform_kernels", version=1)

# Random half-precision input on the CUDA device: `rows` rows of `width` values.
rows, width = 4, 4096
x = torch.randn(rows, width, device="cuda", dtype=torch.float16)

# The scale passed to the transform is 1/sqrt(width), where `width` is the
# size of the last dimension of `x`.
inv_sqrt_width = 1.0 / math.sqrt(width)
y = hadamard.hadamard_transform(x, scale=inv_sqrt_width)

# Show the shape and dtype of the result.
print(y.shape, y.dtype)