File size: 273 Bytes
67a5826
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
"""Smoke example for galqiwi/flute_kernels."""

import torch
from kernels import get_kernel

# Fail fast on machines without a GPU. Otherwise get_kernel() runs first (and
# may fetch the kernel package) and the script only dies later, at the CUDA
# tensor allocation below.
if not torch.cuda.is_available():
    raise RuntimeError("This smoke example requires a CUDA-capable GPU.")

# Load version 1 of the galqiwi/flute_kernels kernel package.
flute = get_kernel("galqiwi/flute_kernels", version=1)

# Random float16 input on the GPU with shape (4, 4096).
x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
# NOTE(review): the meaning of the second positional argument (False) is not
# visible here — confirm against the kernel's signature (possibly an in-place
# flag) and consider passing it by keyword if the op supports that.
y = flute.hadamard_transform(x, False)
# CUDA launches are asynchronous and printing shape/dtype does not force a
# sync, so synchronize explicitly: any kernel failure is then raised here
# instead of going unnoticed or surfacing at an unrelated later call.
torch.cuda.synchronize()
print(y.shape, y.dtype)