File size: 2,664 Bytes
adc211c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch

from kernels.benchmark import Benchmark


def rwkv_wkv_reference(
    w: torch.Tensor, u: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """Evaluate the WKV recurrence one time step at a time.

    All arithmetic is done in float32. The running state is a numerator, a
    denominator and the exponent both are scaled by; every exponential is
    taken relative to an element-wise maximum.

    Args:
        w: Term added to the running exponent before each state update.
            Assumed to broadcast against ``(B, C)`` -- confirm with callers.
        u: Term added to the current key when producing a step's output.
            Assumed to broadcast against ``(B, C)`` -- confirm with callers.
        k: Keys of shape ``(B, T, C)``. The result takes its shape, dtype
            and device from this tensor.
        v: Values, indexed the same way as ``k``.

    Returns:
        A ``(B, T, C)`` tensor with the dtype and device of ``k``.
    """
    batch, steps, channels = k.shape
    out_dtype = k.dtype

    def _shifted_exps(lhs: torch.Tensor, rhs: torch.Tensor):
        # Subtract the element-wise max before exponentiating, so both
        # exponents are <= 0 and exp() stays within [0, 1].
        peak = torch.maximum(lhs, rhs)
        return peak, torch.exp(lhs - peak), torch.exp(rhs - peak)

    result = k.new_zeros((batch, steps, channels))

    # Running state in float32: numerator, denominator, and the exponent
    # they are currently scaled by (starts at a very negative value so the
    # empty state contributes nothing to the first step).
    numer = k.new_zeros((batch, channels), dtype=torch.float32)
    denom = k.new_zeros((batch, channels), dtype=torch.float32)
    running_max = k.new_full((batch, channels), -1e38, dtype=torch.float32)

    w32 = w.float()
    u32 = u.float()

    for step in range(steps):
        key_t = k[:, step].float()  # [B, C]
        val_t = v[:, step].float()  # [B, C]

        # Output for this step: blend the stored state with the current
        # token, whose exponent is u + k_t.
        _, state_wt, token_wt = _shifted_exps(running_max, u32 + key_t)
        blended = (state_wt * numer + token_wt * val_t) / (
            state_wt * denom + token_wt
        )
        result[:, step] = blended.to(out_dtype)

        # State update: the stored exponent becomes w + running_max (w is
        # added, not subtracted) before the current token is folded in.
        running_max, state_wt, token_wt = _shifted_exps(w32 + running_max, key_t)
        numer = state_wt * numer + token_wt * val_t
        denom = state_wt * denom + token_wt

    return result


class RwkvBenchmark(Benchmark):
    """Benchmark cases that run ``self.kernel.forward`` on random WKV inputs.

    Two workloads are defined: the default one (B=2, T=64, C=256) and a
    "large" one (B=8, T=256, C=512). Each has a ``verify_*`` method that
    returns the output of ``rwkv_wkv_reference`` for the same inputs.

    ``self.device`` and ``self.kernel`` are not assigned in this class;
    presumably the ``Benchmark`` base supplies them -- confirm.
    """

    # Presumably read by the Benchmark harness to seed the RNG -- confirm.
    seed: int = 42

    def _populate(self, batch: int, steps: int, channels: int) -> None:
        """Create the float32 inputs and the output buffer on ``self.device``."""
        opts = {"device": self.device, "dtype": torch.float32}
        full_shape = (batch, steps, channels)

        # Draw order (w, u, k, v) is fixed so that a seeded generator
        # produces the same tensors on every run.
        self.w = torch.randn(channels, **opts).abs()  # Decay should be positive
        self.u = torch.randn(channels, **opts)
        self.k = torch.randn(full_shape, **opts) * 0.1
        self.v = torch.randn(full_shape, **opts) * 0.1
        self.out = torch.zeros(full_shape, **opts)

    def setup(self):
        """Prepare the default workload: B=2, T=64, C=256."""
        self._populate(batch=2, steps=64, channels=256)

    def benchmark_base(self):
        """Clear the output buffer and run the kernel on the default inputs."""
        out = self.out
        out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, out)

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the default inputs."""
        return rwkv_wkv_reference(w=self.w, u=self.u, k=self.k, v=self.v)

    def setup_large(self):
        """Prepare the large workload: B=8, T=256, C=512."""
        self._populate(batch=8, steps=256, channels=512)

    def benchmark_large(self):
        """Clear the output buffer and run the kernel on the large inputs."""
        out = self.out
        out.zero_()
        self.kernel.forward(self.w, self.u, self.k, self.v, out)

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the large inputs."""
        return rwkv_wkv_reference(w=self.w, u=self.u, k=self.k, v=self.v)