| | import torch |
| |
|
| | from kernels.benchmark import Benchmark |
| |
|
| |
|
def rwkv_wkv_reference(
    w: torch.Tensor, u: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """Evaluate the WKV recurrence one time step at a time in plain PyTorch.

    For every step ``t`` the output is a weighted average of ``v``:
    the current step is weighted by ``exp(u + k[t])`` and each earlier step
    ``i`` by ``exp((t - 1 - i) * w + k[i])``.  The sums are carried in a
    rescaled form (a running maximum exponent is factored out) so that the
    exponentials never overflow.

    Args:
        w: Per-step increment added to the log-domain state.  It is used
            as-is: no negation or exponentiation is applied here.  Must
            broadcast against ``(B, C)``.
        u: Extra log-weight applied to the current step only.  Must
            broadcast against ``(B, C)``.
        k: Keys of shape ``(B, T, C)``; also fixes the output device/dtype.
        v: Values, indexed the same way as ``k`` (expected ``(B, T, C)``).

    Returns:
        Tensor of shape ``(B, T, C)`` on ``k``'s device with ``k``'s dtype.
        All intermediate arithmetic is carried out in float32.
    """
    batch, steps, channels = k.shape
    out_dtype = k.dtype
    where = k.device
    state_shape = (batch, channels)

    result = torch.zeros(batch, steps, channels, device=where, dtype=out_dtype)

    # Recurrent state, always float32 regardless of the input dtype:
    #   numer / denom -- weighted sums of v and of the weights, both divided
    #                    by exp(log_scale)
    #   log_scale     -- exponent factored out of numer/denom; starts at a
    #                    huge negative value so the empty history has
    #                    weight exp(-1e38 - x) == 0 on the first step
    numer = torch.zeros(state_shape, device=where, dtype=torch.float32)
    denom = torch.zeros(state_shape, device=where, dtype=torch.float32)
    log_scale = torch.full(state_shape, -1e38, device=where, dtype=torch.float32)

    decay = w.float()
    bonus = u.float()

    def _shifted_exp(first, second):
        # Subtract the element-wise maximum before exponentiating so both
        # factors lie in (0, 1]; the maximum is returned as the new scale.
        peak = torch.maximum(first, second)
        return torch.exp(first - peak), torch.exp(second - peak), peak

    for step in range(steps):
        key = k[:, step, :].float()
        val = v[:, step, :].float()

        # Output: history (at its current scale) blended with the current
        # token, which receives the additional `bonus` log-weight.
        hist_w, cur_w, _ = _shifted_exp(log_scale, bonus + key)
        blended = (hist_w * numer + cur_w * val) / (hist_w * denom + cur_w)
        result[:, step, :] = blended.to(out_dtype)

        # State update: advance the history by `decay`, then absorb the
        # current token (this time without the bonus).
        hist_w, cur_w, new_scale = _shifted_exp(decay + log_scale, key)
        numer = hist_w * numer + cur_w * val
        denom = hist_w * denom + cur_w
        log_scale = new_scale

    return result
| |
|
| |
|
class RwkvBenchmark(Benchmark):
    """Benchmark workloads for the WKV kernel, checked against the PyTorch reference.

    Two workloads are defined, ``base`` (B=2, T=64, C=256) and ``large``
    (B=8, T=256, C=512).  Each has a setup method that creates the inputs,
    a benchmark method that runs ``self.kernel.forward`` into ``self.out``,
    and a verify method that returns the reference result for the same
    inputs.

    ``self.device`` and ``self.kernel`` are not defined in this class --
    presumably they are provided by the ``Benchmark`` base; confirm there.
    """

    # NOTE(review): presumably read by the Benchmark base to seed the RNG
    # before setup runs -- confirm against the base class.
    seed: int = 42

    def _allocate(self, batch, seq_len, channels):
        """Create float32 inputs and a zeroed output buffer on ``self.device``."""
        common = {"device": self.device, "dtype": torch.float32}
        seq_shape = (batch, seq_len, channels)

        # The draw order (w, u, k, v) is kept fixed so that a seeded run
        # always produces the same tensors.
        self.w = torch.randn(channels, **common).abs()
        self.u = torch.randn(channels, **common)
        # k and v are scaled down to 0.1 standard deviation.
        self.k = torch.randn(seq_shape, **common) * 0.1
        self.v = torch.randn(seq_shape, **common) * 0.1
        self.out = torch.zeros(seq_shape, **common)

    def _launch(self):
        """Clear the output buffer and run the kernel forward pass into it."""
        self.out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, self.out)

    def setup(self):
        """Prepare the base workload: B=2, T=64, C=256."""
        self._allocate(2, 64, 256)

    def benchmark_base(self):
        """Run the kernel on the base workload."""
        self._launch()

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the current inputs."""
        return rwkv_wkv_reference(self.w, self.u, self.k, self.v)

    def setup_large(self):
        """Prepare the large workload: B=8, T=256, C=512."""
        self._allocate(8, 256, 512)

    def benchmark_large(self):
        """Run the kernel on the large workload."""
        self._launch()

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the current inputs."""
        return rwkv_wkv_reference(self.w, self.u, self.k, self.v)
| |
|