# HARE / birwkv7.py
import torch
import torch.nn as nn
import torch.nn.functional as F
_TRITON_AVAILABLE = False
try:
    import triton
    import triton.language as tl

    @triton.jit
    def _wkv7_fwd_kernel(
        R, K, V, DECAY, A, O,
        STATE_OUT, STATE_IN,
        sab_scale, T,
        stride_b, stride_t, stride_h,
        H: tl.constexpr, D: tl.constexpr, BLOCK_D: tl.constexpr,
        RETURN_STATE: tl.constexpr, HAS_INIT_STATE: tl.constexpr,
    ):
        # One program per (batch, head) pair; each one scans its D x D state over time.
        pid = tl.program_id(0)
        b_idx = pid // H
        h_idx = pid % H
        base = b_idx * stride_b + h_idx * stride_h
        di = tl.arange(0, BLOCK_D)
        dj = tl.arange(0, BLOCK_D)
        mask_i = di < D
        mask_j = dj < D
        if HAS_INIT_STATE:
            s_off = b_idx * (H * D * D) + h_idx * (D * D)
            state_ptrs = STATE_IN + s_off + di[:, None] * D + dj[None, :]
            state_mask = mask_i[:, None] & mask_j[None, :]
            state = tl.load(state_ptrs, mask=state_mask, other=0.0).to(tl.float32)
        else:
            state = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)
        for t in range(T):
            t_off = base + t * stride_t
            kt = tl.load(K + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
            vt = tl.load(V + t_off + di, mask=mask_i, other=0.0).to(tl.float32)
            rt = tl.load(R + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
            dt = tl.load(DECAY + t_off + dj, mask=mask_j, other=1.0).to(tl.float32)
            at = tl.load(A + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
            # Decay the state, add the gated correction term (sab) built from the old state,
            # add the rank-1 write v k^T, then clamp for stability before reading out with r.
            sa = tl.sum(state * (-kt)[None, :], axis=1)
            ka = kt * at
            sab = sa[:, None] * ka[None, :]
            state = state * dt[None, :] + sab_scale * sab + vt[:, None] * kt[None, :]
            state = tl.minimum(tl.maximum(state, -10.0), 10.0)
            out_t = tl.sum(state * rt[None, :], axis=1)
            tl.store(O + t_off + di, out_t, mask=mask_i)
        if RETURN_STATE:
            s_off = b_idx * (H * D * D) + h_idx * (D * D)
            state_ptrs = STATE_OUT + s_off + di[:, None] * D + dj[None, :]
            state_mask = mask_i[:, None] & mask_j[None, :]
            tl.store(state_ptrs, state, mask=state_mask)

    def _wkv7_scan_triton(r, decay, k, v, a, sab_scale):
        B, T, H, D = r.shape
        r, k, v, decay, a = [x.contiguous() for x in (r, k, v, decay, a)]
        o = torch.empty_like(r)
        stride_b, stride_t, stride_h = T * H * D, H * D, D
        BLOCK_D = triton.next_power_of_2(D)
        _wkv7_fwd_kernel[(B * H,)](
            r, k, v, decay, a, o,
            None, None,
            float(sab_scale), T,
            stride_b, stride_t, stride_h,
            H=H, D=D, BLOCK_D=BLOCK_D,
            RETURN_STATE=False, HAS_INIT_STATE=False,
        )
        return o

    if torch.cuda.is_available():
        _TRITON_AVAILABLE = True
except Exception:
    pass
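
# Note (added for clarity): the Triton scan above is launched directly, without an
# autograd.Function wrapper, so it only computes the forward pass; gradients do not
# flow through this code path.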
_FLA_AVAILABLE = False
try:
    # Compatibility shim: some torch builds only expose Replicate/Shard/Placement/
    # distribute_module under torch.distributed._tensor, while fla imports them from
    # torch.distributed.tensor; alias the missing names before importing fla.
    import torch.distributed.tensor as _tdt
    if not hasattr(_tdt, 'Replicate'):
        try:
            from torch.distributed._tensor import Replicate as _R, Shard as _S
            _tdt.Replicate = _R; _tdt.Shard = _S
        except ImportError:
            pass
    if not hasattr(_tdt, 'Placement'):
        try:
            from torch.distributed._tensor.placement_types import Placement as _P
            _tdt.Placement = _P
        except ImportError:
            pass
    if not hasattr(_tdt, 'distribute_module'):
        _tdt.distribute_module = lambda *a, **kw: None
    from fla.ops.rwkv7 import chunk_rwkv7 as _fla_chunk_rwkv7
    if torch.cuda.is_available():
        # Smoke test: run a tiny forward/backward and only enable the backend if the
        # gradients come back free of NaNs.
        _test_r = torch.randn(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)
        _test_w = -torch.ones(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16)
        _test_o, _ = _fla_chunk_rwkv7(_test_r, _test_w, _test_r, _test_r, _test_r, _test_r,
                                      head_first=False)
        _test_o.sum().backward()
        if not _test_r.grad.isnan().any():
            _FLA_AVAILABLE = True
        del _test_r, _test_w, _test_o
        torch.cuda.empty_cache()
    else:
        _FLA_AVAILABLE = True
except Exception:
    pass
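
# Per head, each directional scan keeps a D x D state S_t and computes
# (see _wkv7_scan_python for the reference form):
#   S_t = S_{t-1} @ diag(decay_t) + sab_scale * (S_{t-1} @ (-k_t)) (k_t * a_t)^T + v_t k_t^T
#   o_t = S_t r_t
# with the state clamped to [-10, 10] in the Triton and Python backends; forward()
# averages a left-to-right and a right-to-left pass of this scan.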
class BiRWKV7Layer(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        # Per-channel token-shift mixing coefficients, one set per branch.
        self.mu_r = nn.Parameter(torch.zeros(hidden_size))
        self.mu_w = nn.Parameter(torch.zeros(hidden_size))
        self.mu_k = nn.Parameter(torch.zeros(hidden_size))
        self.mu_v = nn.Parameter(torch.zeros(hidden_size))
        self.mu_a = nn.Parameter(torch.zeros(hidden_size))
        self.mu_g = nn.Parameter(torch.zeros(hidden_size))
        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_w = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_g = nn.Linear(hidden_size, hidden_size, bias=False)
        # Learnable gate on the sab correction term; sigmoid(-5) starts it near zero.
        self.sab_gate = nn.Parameter(torch.tensor(-5.0))
        self.group_norm = nn.GroupNorm(num_heads, hidden_size)
        self.W_o = nn.Linear(hidden_size, hidden_size, bias=False)
        nn.init.normal_(self.W_w.weight, std=0.01)
        nn.init.normal_(self.W_a.weight, std=0.01)
        nn.init.normal_(self.W_g.weight, std=0.02)
    def _token_shift(self, x):
        # x_prev[t] = x[t - 1], with zeros at t = 0; each branch interpolates between the
        # current and previous token with its own learned per-channel weight.
        x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))
        def mix(mu):
            return x + (x_prev - x) * torch.sigmoid(mu)
        return {
            'r': mix(self.mu_r), 'w': mix(self.mu_w),
            'k': mix(self.mu_k), 'v': mix(self.mu_v),
            'a': mix(self.mu_a), 'g': mix(self.mu_g),
        }
    def _wkv7_scan_fla(self, r, w, k, v, a, sab_scale):
        B, T, H, D = r.shape
        orig_dtype = r.dtype
        r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
        # Same parameterisation as the Python scan: scaled keys, log-space decay in
        # (-exp(-0.5), 0), and a_fla = -k_scaled, b_fla = sab_scale * k_scaled * sigmoid(a).
        k_scaled = k * (D ** -0.5)
        w_log = -0.6065306597633104 * torch.sigmoid(w)
        a_sig = torch.sigmoid(a)
        a_fla = -k_scaled
        b_fla = sab_scale * k_scaled * a_sig
        o, _ = _fla_chunk_rwkv7(r, k_scaled, v, a_fla, b_fla, log_w=w_log, scale=1.0, head_first=False)
        return o.to(orig_dtype)
    def _wkv7_scan_python(self, r, w, k, v, a, sab_scale):
        # Reference implementation of the recurrence; slow but backend-free.
        B, T, H, D = r.shape
        orig_dtype = r.dtype
        r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
        k = k * (D ** -0.5)
        decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
        a = torch.sigmoid(a)
        state = torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
        outputs = []
        for t in range(T):
            # Detach every 16 steps to truncate backprop through the recurrence.
            if t > 0 and t % 16 == 0:
                state = state.detach()
            kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]
            sa = torch.einsum('bhij,bhj->bhi', state, -kt)
            sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
            state = state * dt.unsqueeze(-2) + sab_scale * sab + torch.einsum('bhi,bhj->bhij', vt, kt)
            state = state.clamp(-10.0, 10.0)
            outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))
        return torch.stack(outputs, dim=1).to(orig_dtype)
    def _wkv7_scan(self, r, w, k, v, a, sab_scale):
        # Backend dispatch: prefer the Triton kernel, then fla's chunked kernel,
        # then the pure-PyTorch loop.
        if _TRITON_AVAILABLE and r.is_cuda:
            B, T, H, D = r.shape
            orig_dtype = r.dtype
            r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
            k = k * (D ** -0.5)
            decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
            a = torch.sigmoid(a)
            return _wkv7_scan_triton(r, decay, k, v, a, sab_scale).to(orig_dtype)
        if _FLA_AVAILABLE and r.is_cuda:
            return self._wkv7_scan_fla(r, w, k, v, a, sab_scale)
        return self._wkv7_scan_python(r, w, k, v, a, sab_scale)
    def forward(self, x, attention_mask=None, **kwargs):
        B, T, C = x.shape
        H, D = self.num_heads, self.head_size
        mixed = self._token_shift(x)
        r = self.W_r(mixed['r']).view(B, T, H, D)
        w = self.W_w(mixed['w']).view(B, T, H, D)
        k = self.W_k(mixed['k']).view(B, T, H, D)
        v = self.W_v(mixed['v']).view(B, T, H, D)
        a = self.W_a(mixed['a']).view(B, T, H, D)
        g = torch.sigmoid(self.W_g(mixed['g']))
        sab_scale = torch.sigmoid(self.sab_gate)
        # Bidirectional scan: run the recurrence left-to-right and right-to-left
        # (by flipping the time axis) and average the two outputs.
        out_fwd = self._wkv7_scan(r, w, k, v, a, sab_scale)
        out_bwd = self._wkv7_scan(
            r.flip(1), w.flip(1), k.flip(1), v.flip(1), a.flip(1), sab_scale
        ).flip(1)
        out = (out_fwd + out_bwd).reshape(B, T, C) * 0.5
        out = self.group_norm(out.transpose(1, 2)).transpose(1, 2)
        out = self.W_o(out * g)
        return out, None
def init_from_attention(birwkv, attn_module):
    """Copy Q/K/V/output projection weights from an attention module into a BiRWKV7Layer."""
    q_proj = k_proj = v_proj = o_proj = None
    if hasattr(attn_module, 'Wqkv'):
        # Fused QKV projection, assumed shape (3*C, C): split row-wise into Q, K, V.
        fused = attn_module.Wqkv.weight.data
        C = fused.shape[1]
        q_proj, k_proj, v_proj = fused[:C], fused[C:2*C], fused[2*C:]
    else:
        for name in ['q_proj', 'query', 'W_q', 'wq']:
            if hasattr(attn_module, name):
                q_proj = getattr(attn_module, name).weight.data
                break
        for name in ['k_proj', 'key', 'W_k', 'wk']:
            if hasattr(attn_module, name):
                k_proj = getattr(attn_module, name).weight.data
                break
        for name in ['v_proj', 'value', 'W_v', 'wv']:
            if hasattr(attn_module, name):
                v_proj = getattr(attn_module, name).weight.data
                break
    for name in ['Wo', 'out_proj', 'o_proj', 'dense', 'W_o', 'wo']:
        if hasattr(attn_module, name):
            o_proj = getattr(attn_module, name).weight.data
            break
    transferred = []
    for src, dst, label in [
        (q_proj, birwkv.W_r, 'Q->R'),
        (k_proj, birwkv.W_k, 'K->K'),
        (v_proj, birwkv.W_v, 'V->V'),
        (o_proj, birwkv.W_o, 'O->O'),
    ]:
        if src is not None:
            dst.weight.data.copy_(src)
            transferred.append(label)
    return transferred
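

# Illustrative usage sketch (not part of the original module): the hidden size, head count,
# and the toy attention module below are assumptions chosen only to exercise the layer and
# the weight-transfer helper on CPU, where the pure-PyTorch scan is used.
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = BiRWKV7Layer(hidden_size=256, num_heads=4)
    x = torch.randn(2, 16, 256)  # (batch, time, channels)
    y, _ = layer(x)
    print("output:", tuple(y.shape))  # expected: (2, 16, 256)

    # Hypothetical attention module exposing q_proj/k_proj/v_proj/o_proj, matching the
    # names probed by init_from_attention.
    class _ToyAttention(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.q_proj = nn.Linear(dim, dim, bias=False)
            self.k_proj = nn.Linear(dim, dim, bias=False)
            self.v_proj = nn.Linear(dim, dim, bias=False)
            self.o_proj = nn.Linear(dim, dim, bias=False)

    attn = _ToyAttention(256)
    print("transferred:", init_from_attention(layer, attn))  # expected: ['Q->R', 'K->K', 'V->V', 'O->O']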