import torch
from kernels.benchmark import Benchmark
def rwkv_wkv_reference(
    w: torch.Tensor, u: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """Evaluate the RWKV WKV recurrence one time step at a time.

    The output at step ``t`` is a weighted average of the values seen so
    far: the current token carries weight ``exp(u + k[t])`` and each earlier
    token ``i < t`` carries weight ``exp((t - 1 - i) * w + k[i])``.  The
    weights are never materialised directly; the numerator and denominator
    are stored relative to a running maximum exponent so that every
    ``exp`` argument is <= 0.

    Note that ``w`` is *added* to the running exponent on every step
    (``w + pp``, not ``pp - w``).

    Args:
        w: Per-channel decay term; must broadcast against ``[B, C]``.
        u: Per-channel bonus added to the current token's key; must
            broadcast against ``[B, C]``.
        k: Keys of shape ``[B, T, C]``.  Its dtype and device determine
            those of the result.
        v: Values, indexed the same way as ``k``.

    Returns:
        Tensor of shape ``[B, T, C]`` with ``k``'s dtype and device.  All
        intermediate arithmetic is done in float32.
    """
    batch, steps, channels = k.shape
    target_dtype = k.dtype
    fp32_on_device = {"device": k.device, "dtype": torch.float32}

    outputs = torch.zeros(
        batch, steps, channels, device=k.device, dtype=target_dtype
    )

    # Recurrent state, each [B, C]: weighted value sum, weight sum, and the
    # exponent both sums are currently scaled by.  -1e38 acts as "-inf"
    # while staying finite in float32.
    numer = torch.zeros(batch, channels, **fp32_on_device)
    denom = torch.zeros(batch, channels, **fp32_on_device)
    log_scale = torch.full((batch, channels), -1e38, **fp32_on_device)

    decay = w.float()
    bonus = u.float()

    for step in range(steps):
        key_t = k[:, step, :].float()  # [B, C]
        val_t = v[:, step, :].float()  # [B, C]

        # Output: blend the history with the current token, which gets the
        # extra `u` bonus on its key.
        _, keep, add = _shared_max_weights(log_scale, bonus + key_t)
        blended = (keep * numer + add * val_t) / (keep * denom + add)
        outputs[:, step, :] = blended.to(target_dtype)

        # State update: shift the history by `w` (added, not subtracted),
        # then fold in the current token without the bonus.
        log_scale, keep, add = _shared_max_weights(decay + log_scale, key_t)
        numer = keep * numer + add * val_t
        denom = keep * denom + add

    return outputs


def _shared_max_weights(history_exp: torch.Tensor, token_exp: torch.Tensor):
    """Return ``(m, exp(history_exp - m), exp(token_exp - m))`` with ``m`` the elementwise max."""
    shared = torch.maximum(history_exp, token_exp)
    return shared, torch.exp(history_exp - shared), torch.exp(token_exp - shared)
class RwkvBenchmark(Benchmark):
    """Benchmark workloads for an RWKV WKV kernel.

    Two workloads are defined, ``base`` (B=2, T=64, C=256) and ``large``
    (B=8, T=256, C=512).  Each has a setup method that builds random
    float32 inputs, a ``benchmark_*`` method that calls
    ``self.kernel.forward``, and a ``verify_*`` method that returns the
    output of ``rwkv_wkv_reference`` for the same inputs.

    NOTE(review): ``self.device`` and ``self.kernel`` are not set here;
    presumably the ``Benchmark`` base class provides them -- confirm.
    """

    seed: int = 42

    def _allocate_inputs(self, batch: int, seq_len: int, channels: int) -> None:
        """Create random ``w``/``u``/``k``/``v`` and a zeroed ``out`` buffer.

        Tensors are drawn in the order w, u, k, v so the random stream is
        consumed the same way for every workload.
        """
        opts = {"device": self.device, "dtype": torch.float32}
        full_shape = (batch, seq_len, channels)

        # Decay should be positive
        self.w = torch.randn(channels, **opts).abs()
        self.u = torch.randn(channels, **opts)
        # Keys and values are scaled down to 0.1 of a standard normal.
        self.k = torch.randn(*full_shape, **opts) * 0.1
        self.v = torch.randn(*full_shape, **opts) * 0.1
        self.out = torch.zeros(*full_shape, **opts)

    def setup(self):
        """Prepare inputs for the ``base`` workload (B=2, T=64, C=256)."""
        self._allocate_inputs(batch=2, seq_len=64, channels=256)

    def benchmark_base(self):
        """Clear ``out`` and run the kernel on the ``base`` inputs."""
        self.out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, self.out)

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the ``base`` inputs."""
        return rwkv_wkv_reference(self.w, self.u, self.k, self.v)

    def setup_large(self):
        """Prepare inputs for the ``large`` workload (B=8, T=256, C=512)."""
        self._allocate_inputs(batch=8, seq_len=256, channels=512)

    def benchmark_large(self):
        """Clear ``out`` and run the kernel on the ``large`` inputs."""
        self.out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, self.out)

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the ``large`` inputs."""
        return rwkv_wkv_reference(self.w, self.u, self.k, self.v)
|