SixOpen committed on
Commit
f304ad1
·
verified ·
1 Parent(s): 7d6b779

Update birwkv7.py

Browse files

Inline Triton kernel and proper FLA arg order

Files changed (1) hide show
  1. birwkv7.py +86 -2
birwkv7.py CHANGED
@@ -2,6 +2,82 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  _FLA_AVAILABLE = False
6
  try:
7
  import torch.distributed.tensor as _tdt
@@ -83,13 +159,13 @@ class BiRWKV7Layer(nn.Module):
83
  def _wkv7_scan_fla(self, r, w, k, v, a, sab_scale):
84
  B, T, H, D = r.shape
85
  orig_dtype = r.dtype
86
- r, w, k, v, a = [x.bfloat16() for x in (r, w, k, v, a)]
87
  k_scaled = k * (D ** -0.5)
88
  w_log = -0.6065306597633104 * torch.sigmoid(w)
89
  a_sig = torch.sigmoid(a)
90
  a_fla = -k_scaled
91
  b_fla = sab_scale * k_scaled * a_sig
92
- o, _ = _fla_chunk_rwkv7(r, w_log, k_scaled, v, a_fla, b_fla, scale=1.0)
93
  return o.to(orig_dtype)
94
 
95
  def _wkv7_scan_python(self, r, w, k, v, a, sab_scale):
@@ -120,6 +196,14 @@ class BiRWKV7Layer(nn.Module):
120
  return torch.stack(outputs, dim=1).to(orig_dtype)
121
 
122
  def _wkv7_scan(self, r, w, k, v, a, sab_scale):
 
 
 
 
 
 
 
 
123
  if _FLA_AVAILABLE and r.is_cuda:
124
  return self._wkv7_scan_fla(r, w, k, v, a, sab_scale)
125
  return self._wkv7_scan_python(r, w, k, v, a, sab_scale)
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
+ _TRITON_AVAILABLE = False
6
+ try:
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ @triton.jit
11
+ def _wkv7_fwd_kernel(
12
+ R, K, V, DECAY, A, O,
13
+ STATE_OUT, STATE_IN,
14
+ sab_scale, T,
15
+ stride_b, stride_t, stride_h,
16
+ H: tl.constexpr, D: tl.constexpr, BLOCK_D: tl.constexpr,
17
+ RETURN_STATE: tl.constexpr, HAS_INIT_STATE: tl.constexpr,
18
+ ):
19
+ pid = tl.program_id(0)
20
+ b_idx = pid // H
21
+ h_idx = pid % H
22
+ base = b_idx * stride_b + h_idx * stride_h
23
+
24
+ di = tl.arange(0, BLOCK_D)
25
+ dj = tl.arange(0, BLOCK_D)
26
+ mask_i = di < D
27
+ mask_j = dj < D
28
+
29
+ if HAS_INIT_STATE:
30
+ s_off = b_idx * (H * D * D) + h_idx * (D * D)
31
+ state_ptrs = STATE_IN + s_off + di[:, None] * D + dj[None, :]
32
+ state_mask = mask_i[:, None] & mask_j[None, :]
33
+ state = tl.load(state_ptrs, mask=state_mask, other=0.0).to(tl.float32)
34
+ else:
35
+ state = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)
36
+
37
+ for t in range(T):
38
+ t_off = base + t * stride_t
39
+ kt = tl.load(K + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
40
+ vt = tl.load(V + t_off + di, mask=mask_i, other=0.0).to(tl.float32)
41
+ rt = tl.load(R + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
42
+ dt = tl.load(DECAY + t_off + dj, mask=mask_j, other=1.0).to(tl.float32)
43
+ at = tl.load(A + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
44
+
45
+ sa = tl.sum(state * (-kt)[None, :], axis=1)
46
+ ka = kt * at
47
+ sab = sa[:, None] * ka[None, :]
48
+ state = state * dt[None, :] + sab_scale * sab + vt[:, None] * kt[None, :]
49
+ state = tl.minimum(tl.maximum(state, -10.0), 10.0)
50
+
51
+ out_t = tl.sum(state * rt[None, :], axis=1)
52
+ tl.store(O + t_off + di, out_t, mask=mask_i)
53
+
54
+ if RETURN_STATE:
55
+ s_off = b_idx * (H * D * D) + h_idx * (D * D)
56
+ state_ptrs = STATE_OUT + s_off + di[:, None] * D + dj[None, :]
57
+ state_mask = mask_i[:, None] & mask_j[None, :]
58
+ tl.store(state_ptrs, state, mask=state_mask)
59
+
60
+ def _wkv7_scan_triton(r, decay, k, v, a, sab_scale):
61
+ B, T, H, D = r.shape
62
+ r, k, v, decay, a = [x.contiguous() for x in (r, k, v, decay, a)]
63
+ o = torch.empty_like(r)
64
+ stride_b, stride_t, stride_h = T * H * D, H * D, D
65
+ BLOCK_D = triton.next_power_of_2(D)
66
+ _wkv7_fwd_kernel[(B * H,)](
67
+ r, k, v, decay, a, o,
68
+ None, None,
69
+ float(sab_scale), T,
70
+ stride_b, stride_t, stride_h,
71
+ H=H, D=D, BLOCK_D=BLOCK_D,
72
+ RETURN_STATE=False, HAS_INIT_STATE=False,
73
+ )
74
+ return o
75
+
76
+ if torch.cuda.is_available():
77
+ _TRITON_AVAILABLE = True
78
+ except Exception:
79
+ pass
80
+
81
  _FLA_AVAILABLE = False
82
  try:
83
  import torch.distributed.tensor as _tdt
 
159
  def _wkv7_scan_fla(self, r, w, k, v, a, sab_scale):
160
  B, T, H, D = r.shape
161
  orig_dtype = r.dtype
162
+ r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
163
  k_scaled = k * (D ** -0.5)
164
  w_log = -0.6065306597633104 * torch.sigmoid(w)
165
  a_sig = torch.sigmoid(a)
166
  a_fla = -k_scaled
167
  b_fla = sab_scale * k_scaled * a_sig
168
+ o, _ = _fla_chunk_rwkv7(r, k_scaled, v, a_fla, b_fla, log_w=w_log, scale=1.0, head_first=False)
169
  return o.to(orig_dtype)
170
 
171
  def _wkv7_scan_python(self, r, w, k, v, a, sab_scale):
 
196
  return torch.stack(outputs, dim=1).to(orig_dtype)
197
 
198
  def _wkv7_scan(self, r, w, k, v, a, sab_scale):
199
+ if _TRITON_AVAILABLE and r.is_cuda:
200
+ B, T, H, D = r.shape
201
+ orig_dtype = r.dtype
202
+ r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
203
+ k = k * (D ** -0.5)
204
+ decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
205
+ a = torch.sigmoid(a)
206
+ return _wkv7_scan_triton(r, decay, k, v, a, sab_scale).to(orig_dtype)
207
  if _FLA_AVAILABLE and r.is_cuda:
208
  return self._wkv7_scan_fla(r, w, k, v, a, sab_scale)
209
  return self._wkv7_scan_python(r, w, k, v, a, sab_scale)