danieldk HF Staff committed on
Commit
adc211c
·
verified ·
1 Parent(s): 5f8a57e

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +81 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from kernels.benchmark import Benchmark
4
+
5
+
6
def rwkv_wkv_reference(
    w: torch.Tensor, u: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """Compute the WKV recurrence step by step in plain PyTorch.

    The recurrence walks the middle axis of ``k``/``v`` one step at a time
    and keeps a running weighted numerator and denominator together with
    the exponent they are scaled by, so that every ``exp`` call receives a
    non-positive argument.

    Args:
        w: Decay term added to the running exponent at every state update.
            Must broadcast against a ``[B, C]`` tensor.
        u: Bonus term added to the current key when producing the output.
            Must broadcast against a ``[B, C]`` tensor.
        k: Keys, a 3-D tensor unpacked as ``[B, T, C]``.
        v: Values, indexed the same way as ``k``.

    Returns:
        A ``[B, T, C]`` tensor on ``k``'s device with ``k``'s dtype. All
        intermediate arithmetic is done in float32.
    """
    batch, steps, channels = k.shape
    out_dtype = k.dtype
    state_opts = {"device": k.device, "dtype": torch.float32}

    # Running state: weighted numerator, weighted denominator, and the
    # exponent both are currently scaled by. The very negative initial
    # exponent makes the (empty) state contribute nothing at the first step.
    numer = torch.zeros(batch, channels, **state_opts)
    denom = torch.zeros(batch, channels, **state_opts)
    max_exp = torch.full((batch, channels), -1e38, **state_opts)

    decay = w.float()
    bonus = u.float()

    y = torch.zeros(batch, steps, channels, device=k.device, dtype=out_dtype)

    for t in range(steps):
        key_t = k[:, t, :].float()  # [B, C]
        val_t = v[:, t, :].float()  # [B, C]

        # Output: blend the accumulated state with the current token, whose
        # key receives the bonus ``u``. Both exponents are shifted by their
        # maximum before exponentiation.
        boosted = bonus + key_t
        shift = torch.maximum(max_exp, boosted)
        state_scale = torch.exp(max_exp - shift)
        token_scale = torch.exp(boosted - shift)
        blended_num = state_scale * numer + token_scale * val_t
        blended_den = state_scale * denom + token_scale
        y[:, t, :] = (blended_num / blended_den).to(out_dtype)

        # State update: the decay is *added* to the running exponent
        # (w + pp, not pp - w), then the un-boosted key is folded in.
        decayed = decay + max_exp
        shift = torch.maximum(decayed, key_t)
        state_scale = torch.exp(decayed - shift)
        token_scale = torch.exp(key_t - shift)
        numer = state_scale * numer + token_scale * val_t
        denom = state_scale * denom + token_scale
        max_exp = shift

    return y
44
+
45
+
46
class RwkvBenchmark(Benchmark):
    """Benchmark workloads for the WKV kernel's ``forward`` entry point.

    Two workloads are defined, ``base`` and ``large``. They differ only in
    the ``(B, T, C)`` problem size. Each has a ``benchmark_*`` method that
    runs the kernel into ``self.out`` and a ``verify_*`` method that returns
    the expected result from ``rwkv_wkv_reference``.

    NOTE(review): ``self.device`` and ``self.kernel`` are not set in this
    class; presumably the ``Benchmark`` base class provides them, and pairs
    ``setup``/``setup_large`` with the matching workload — confirm against
    the ``kernels`` benchmark runner.
    """

    seed: int = 42

    def _init_tensors(self, batch: int, seq_len: int, channels: int) -> None:
        """Create the float32 inputs and output buffer for one problem size.

        The tensors are drawn in the order ``w, u, k, v`` so the random
        stream is consumed in the same order for every workload.
        """
        opts = {"device": self.device, "dtype": torch.float32}

        self.w = torch.randn(channels, **opts).abs()  # Decay should be positive
        self.u = torch.randn(channels, **opts)
        # Keys and values are scaled down to keep the exponents small.
        self.k = torch.randn(batch, seq_len, channels, **opts) * 0.1
        self.v = torch.randn(batch, seq_len, channels, **opts) * 0.1
        self.out = torch.zeros(batch, seq_len, channels, **opts)

    def setup(self):
        """Prepare the ``base`` workload: B=2, T=64, C=256."""
        self._init_tensors(2, 64, 256)

    def benchmark_base(self):
        """Run the kernel on the ``base`` workload, writing into ``self.out``."""
        # Kept inline (no shared helper) so the timed body stays unchanged.
        self.out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, self.out)

    def verify_base(self) -> torch.Tensor:
        """Return the reference result for the ``base`` workload."""
        return rwkv_wkv_reference(self.w, self.u, self.k, self.v)

    def setup_large(self):
        """Prepare the ``large`` workload: B=8, T=256, C=512."""
        self._init_tensors(8, 256, 512)

    def benchmark_large(self):
        """Run the kernel on the ``large`` workload, writing into ``self.out``."""
        # Kept inline (no shared helper) so the timed body stays unchanged.
        self.out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, self.out)

    def verify_large(self) -> torch.Tensor:
        """Return the reference result for the ``large`` workload."""
        return rwkv_wkv_reference(self.w, self.u, self.k, self.v)