Update birwkv7.py

Inline Triton kernel and proper FLA arg order

birwkv7.py  +86 -2
@@ -2,6 +2,82 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+_TRITON_AVAILABLE = False
+try:
+    import triton
+    import triton.language as tl
+
+    @triton.jit
+    def _wkv7_fwd_kernel(
+        R, K, V, DECAY, A, O,
+        STATE_OUT, STATE_IN,
+        sab_scale, T,
+        stride_b, stride_t, stride_h,
+        H: tl.constexpr, D: tl.constexpr, BLOCK_D: tl.constexpr,
+        RETURN_STATE: tl.constexpr, HAS_INIT_STATE: tl.constexpr,
+    ):
+        pid = tl.program_id(0)
+        b_idx = pid // H
+        h_idx = pid % H
+        base = b_idx * stride_b + h_idx * stride_h
+
+        di = tl.arange(0, BLOCK_D)
+        dj = tl.arange(0, BLOCK_D)
+        mask_i = di < D
+        mask_j = dj < D
+
+        if HAS_INIT_STATE:
+            s_off = b_idx * (H * D * D) + h_idx * (D * D)
+            state_ptrs = STATE_IN + s_off + di[:, None] * D + dj[None, :]
+            state_mask = mask_i[:, None] & mask_j[None, :]
+            state = tl.load(state_ptrs, mask=state_mask, other=0.0).to(tl.float32)
+        else:
+            state = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)
+
+        for t in range(T):
+            t_off = base + t * stride_t
+            kt = tl.load(K + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
+            vt = tl.load(V + t_off + di, mask=mask_i, other=0.0).to(tl.float32)
+            rt = tl.load(R + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
+            dt = tl.load(DECAY + t_off + dj, mask=mask_j, other=1.0).to(tl.float32)
+            at = tl.load(A + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
+
+            sa = tl.sum(state * (-kt)[None, :], axis=1)
+            ka = kt * at
+            sab = sa[:, None] * ka[None, :]
+            state = state * dt[None, :] + sab_scale * sab + vt[:, None] * kt[None, :]
+            state = tl.minimum(tl.maximum(state, -10.0), 10.0)
+
+            out_t = tl.sum(state * rt[None, :], axis=1)
+            tl.store(O + t_off + di, out_t, mask=mask_i)
+
+        if RETURN_STATE:
+            s_off = b_idx * (H * D * D) + h_idx * (D * D)
+            state_ptrs = STATE_OUT + s_off + di[:, None] * D + dj[None, :]
+            state_mask = mask_i[:, None] & mask_j[None, :]
+            tl.store(state_ptrs, state, mask=state_mask)
+
+    def _wkv7_scan_triton(r, decay, k, v, a, sab_scale):
+        B, T, H, D = r.shape
+        r, k, v, decay, a = [x.contiguous() for x in (r, k, v, decay, a)]
+        o = torch.empty_like(r)
+        stride_b, stride_t, stride_h = T * H * D, H * D, D
+        BLOCK_D = triton.next_power_of_2(D)
+        _wkv7_fwd_kernel[(B * H,)](
+            r, k, v, decay, a, o,
+            None, None,
+            float(sab_scale), T,
+            stride_b, stride_t, stride_h,
+            H=H, D=D, BLOCK_D=BLOCK_D,
+            RETURN_STATE=False, HAS_INIT_STATE=False,
+        )
+        return o
+
+    if torch.cuda.is_available():
+        _TRITON_AVAILABLE = True
+except Exception:
+    pass
+
 _FLA_AVAILABLE = False
 try:
     import torch.distributed.tensor as _tdt
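The kernel holds one (BLOCK_D, BLOCK_D) state matrix per (batch, head) program and advances it one timestep per loop iteration. As a sanity check, that recurrence can be replayed in plain PyTorch; the sketch below is a hypothetical parity test (wkv7_scan_reference is not part of this commit) assuming the same (B, T, H, D) layout and a zero initial state:

import torch

def wkv7_scan_reference(r, decay, k, v, a, sab_scale):
    # Replays _wkv7_fwd_kernel step by step: state is (B, H, D_value, D_key);
    # decay and the removal term act on key columns, v fills value rows.
    B, T, H, D = r.shape
    o = torch.empty_like(r)
    state = torch.zeros(B, H, D, D, dtype=torch.float32, device=r.device)
    for t in range(T):
        kt, vt, rt = k[:, t], v[:, t], r[:, t]
        dt, at = decay[:, t], a[:, t]
        sa = (state * -kt[:, :, None, :]).sum(-1)         # matches sa in the kernel
        sab = sa[..., :, None] * (kt * at)[..., None, :]  # outer(sa, k * a)
        state = (state * dt[..., None, :]
                 + sab_scale * sab
                 + vt[..., :, None] * kt[..., None, :])
        state = state.clamp(-10.0, 10.0)                  # same clamp as the kernel
        o[:, t] = (state * rt[..., None, :]).sum(-1)
    return o

# Hypothetical usage against the Triton path, on float32 CUDA inputs:
# torch.testing.assert_close(_wkv7_scan_triton(r, decay, k, v, a, 1.0),
#                            wkv7_scan_reference(r, decay, k, v, a, 1.0),
#                            rtol=1e-4, atol=1e-4)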
@@ -83,13 +159,13 @@ class BiRWKV7Layer(nn.Module):
     def _wkv7_scan_fla(self, r, w, k, v, a, sab_scale):
         B, T, H, D = r.shape
         orig_dtype = r.dtype
-        r, w, k, v, a = [x.
+        r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
         k_scaled = k * (D ** -0.5)
         w_log = -0.6065306597633104 * torch.sigmoid(w)
         a_sig = torch.sigmoid(a)
         a_fla = -k_scaled
         b_fla = sab_scale * k_scaled * a_sig
-        o, _ = _fla_chunk_rwkv7(r,
+        o, _ = _fla_chunk_rwkv7(r, k_scaled, v, a_fla, b_fla, log_w=w_log, scale=1.0, head_first=False)
         return o.to(orig_dtype)

     def _wkv7_scan_python(self, r, w, k, v, a, sab_scale):
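Both the Triton and FLA paths derive the per-channel decay through the same constant; 0.6065306597633104 is numerically exp(-0.5), so with sigmoid(w) in (0, 1) the decay exp(-0.6065... * sigmoid(w)) stays in roughly (0.545, 1.0) and the state never grows between steps. A quick check of those bounds (a sketch, just arithmetic):

import math

LOG_SCALE = 0.6065306597633104      # constant from the diff, ≈ exp(-0.5)
print(math.exp(-LOG_SCALE))         # lower bound as sigmoid(w) → 1: ≈ 0.5452
print(math.exp(-LOG_SCALE * 0.0))   # upper bound as sigmoid(w) → 0: 1.0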
@@ -120,6 +196,14 @@ class BiRWKV7Layer(nn.Module):
         return torch.stack(outputs, dim=1).to(orig_dtype)

     def _wkv7_scan(self, r, w, k, v, a, sab_scale):
+        if _TRITON_AVAILABLE and r.is_cuda:
+            B, T, H, D = r.shape
+            orig_dtype = r.dtype
+            r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
+            k = k * (D ** -0.5)
+            decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
+            a = torch.sigmoid(a)
+            return _wkv7_scan_triton(r, decay, k, v, a, sab_scale).to(orig_dtype)
         if _FLA_AVAILABLE and r.is_cuda:
             return self._wkv7_scan_fla(r, w, k, v, a, sab_scale)
         return self._wkv7_scan_python(r, w, k, v, a, sab_scale)
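For a quick end-to-end check of the new path, a minimal smoke test (a sketch: it assumes a CUDA device with triton installed, and mirrors the preprocessing _wkv7_scan now performs before dispatching):

import torch
from birwkv7 import _wkv7_scan_triton  # module-level helper added in this commit

B, T, H, D = 2, 16, 4, 64
r, w, k, v, a = (torch.randn(B, T, H, D, device="cuda") for _ in range(5))

k_scaled = k * (D ** -0.5)                                  # key scaling from _wkv7_scan
decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))   # per-channel decay
a_sig = torch.sigmoid(a)

o = _wkv7_scan_triton(r, decay, k_scaled, v, a_sig, sab_scale=1.0)
assert o.shape == (B, T, H, D)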