# Kernels / cuda / hadamard — file size 334 bytes, e4a8c54
"""Minimal usage example for the published kernel."""

import math

import torch
from kernels import get_kernel

# Load the published Hadamard-transform kernel by repository id.
# NOTE(review): `version=1` presumably pins a kernel release so this example
# keeps working if the published API changes — confirm against `kernels` docs.
hadamard = get_kernel("galqiwi/hadamard_transform_kernels", version=1)

# Example problem size. Keep these as the single source of truth: the scale
# below is derived from the tensor itself, so changing DIM cannot leave a
# stale normalization constant behind.
BATCH_SIZE = 4
DIM = 4096  # NOTE(review): Hadamard kernels usually need a power-of-two size — confirm.

x = torch.randn(BATCH_SIZE, DIM, device="cuda", dtype=torch.float16)

# Scaling by 1/sqrt(n) makes an n-point (+/-1) Hadamard transform orthonormal.
# NOTE(review): assumes the kernel transforms along the last dimension — confirm.
scale = 1.0 / math.sqrt(x.shape[-1])
y = hadamard.hadamard_transform(x, scale=scale)
print(y.shape, y.dtype)