import torch
import torch.nn as nn
import torch.nn.functional as F
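

# WKV7 recurrence computed by every backend below (per head, per step t):
#
#     S_t = S_{t-1} @ diag(w_t)
#           + sab_scale * (S_{t-1} @ (-k_t)) (k_t * a_t)^T
#           + v_t k_t^T
#     o_t = S_t @ r_t
#
# w_t is a per-channel decay in (0, 1), a_t a sigmoid gate on the
# state-feedback (erasure) term, and sab_scale = sigmoid(sab_gate).
# Backends: a fused Triton kernel (inference only), fla's chunk_rwkv7
# (CUDA training), and a pure-PyTorch loop as the portable fallback.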


_TRITON_AVAILABLE = False
try:
    import triton
    import triton.language as tl

    @triton.jit
    def _wkv7_fwd_kernel(
        R, K, V, DECAY, A, O,
        STATE_OUT, STATE_IN,
        sab_scale, T,
        stride_b, stride_t, stride_h,
        H: tl.constexpr, D: tl.constexpr, BLOCK_D: tl.constexpr,
        RETURN_STATE: tl.constexpr, HAS_INIT_STATE: tl.constexpr,
    ):
        # One program per (batch, head) pair; the D x D state matrix is
        # carried in registers across the sequential scan over T.
        pid = tl.program_id(0)
        b_idx = pid // H
        h_idx = pid % H
        base = b_idx * stride_b + h_idx * stride_h

        di = tl.arange(0, BLOCK_D)
        dj = tl.arange(0, BLOCK_D)
        mask_i = di < D
        mask_j = dj < D

        if HAS_INIT_STATE:
            s_off = b_idx * (H * D * D) + h_idx * (D * D)
            state_ptrs = STATE_IN + s_off + di[:, None] * D + dj[None, :]
            state_mask = mask_i[:, None] & mask_j[None, :]
            state = tl.load(state_ptrs, mask=state_mask, other=0.0).to(tl.float32)
        else:
            state = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)

        for t in range(T):
            t_off = base + t * stride_t
            kt = tl.load(K + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
            vt = tl.load(V + t_off + di, mask=mask_i, other=0.0).to(tl.float32)
            rt = tl.load(R + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
            # other=1.0: masked lanes keep decay 1 so their (zero) state stays zero.
            dt = tl.load(DECAY + t_off + dj, mask=mask_j, other=1.0).to(tl.float32)
            at = tl.load(A + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)

            # sa = S @ (-k_t); sab = sa (k_t * a_t)^T is the state-feedback term.
            sa = tl.sum(state * (-kt)[None, :], axis=1)
            ka = kt * at
            sab = sa[:, None] * ka[None, :]
            state = state * dt[None, :] + sab_scale * sab + vt[:, None] * kt[None, :]
            # Clamp matches the PyTorch reference and guards against blow-up.
            state = tl.minimum(tl.maximum(state, -10.0), 10.0)

            # Readout: o_t = S @ r_t.
            out_t = tl.sum(state * rt[None, :], axis=1)
            tl.store(O + t_off + di, out_t, mask=mask_i)

        if RETURN_STATE:
            s_off = b_idx * (H * D * D) + h_idx * (D * D)
            state_ptrs = STATE_OUT + s_off + di[:, None] * D + dj[None, :]
            state_mask = mask_i[:, None] & mask_j[None, :]
            tl.store(state_ptrs, state, mask=state_mask)

    def _wkv7_scan_triton(r, decay, k, v, a, sab_scale):
        # Forward-only launcher: the kernel registers no backward, and
        # float(sab_scale) detaches the gate, so this path must not be
        # taken when gradients are required.
        B, T, H, D = r.shape
        r, k, v, decay, a = [x.contiguous() for x in (r, k, v, decay, a)]
        o = torch.empty_like(r)
        stride_b, stride_t, stride_h = T * H * D, H * D, D
        BLOCK_D = triton.next_power_of_2(D)
        _wkv7_fwd_kernel[(B * H,)](
            r, k, v, decay, a, o,
            None, None,
            float(sab_scale), T,
            stride_b, stride_t, stride_h,
            H=H, D=D, BLOCK_D=BLOCK_D,
            RETURN_STATE=False, HAS_INIT_STATE=False,
        )
        return o

    if torch.cuda.is_available():
        _TRITON_AVAILABLE = True
except Exception:
    pass


_FLA_AVAILABLE = False
try:
    # Some fla releases import names from torch.distributed.tensor that only
    # exist in newer PyTorch versions; shim the missing attributes first.
    import torch.distributed.tensor as _tdt
    if not hasattr(_tdt, 'Replicate'):
        try:
            from torch.distributed._tensor import Replicate as _R, Shard as _S
            _tdt.Replicate = _R
            _tdt.Shard = _S
        except ImportError:
            pass
    if not hasattr(_tdt, 'Placement'):
        try:
            from torch.distributed._tensor.placement_types import Placement as _P
            _tdt.Placement = _P
        except ImportError:
            pass
    if not hasattr(_tdt, 'distribute_module'):
        _tdt.distribute_module = lambda *a, **kw: None
    from fla.ops.rwkv7 import chunk_rwkv7 as _fla_chunk_rwkv7
    if torch.cuda.is_available():
        # Smoke-test forward and backward using the same calling convention
        # as _wkv7_scan_fla below: some fla builds return NaN gradients.
        _test_r = torch.randn(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)
        _test_w = -torch.ones(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16)
        _test_o, _ = _fla_chunk_rwkv7(_test_r, _test_r, _test_r, _test_r, _test_r,
                                      log_w=_test_w, head_first=False)
        _test_o.sum().backward()
        if not _test_r.grad.isnan().any():
            _FLA_AVAILABLE = True
        del _test_r, _test_w, _test_o
        torch.cuda.empty_cache()
    else:
        _FLA_AVAILABLE = True
except Exception:
    pass


class BiRWKV7Layer(nn.Module):
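    """Bidirectional RWKV7 token-mixing layer.

    Runs the WKV7 linear-attention scan in both time directions and averages
    the two passes, so every position sees past and future context. Intended
    as a drop-in mixer for encoder-style (bidirectional) models.
    """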

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads

        # Token-shift mixing coefficients, one per channel and per branch.
        self.mu_r = nn.Parameter(torch.zeros(hidden_size))
        self.mu_w = nn.Parameter(torch.zeros(hidden_size))
        self.mu_k = nn.Parameter(torch.zeros(hidden_size))
        self.mu_v = nn.Parameter(torch.zeros(hidden_size))
        self.mu_a = nn.Parameter(torch.zeros(hidden_size))
        self.mu_g = nn.Parameter(torch.zeros(hidden_size))

        # Branch projections: receptance, key, value, decay, in-context
        # learning rate, and output gate.
        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_w = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_g = nn.Linear(hidden_size, hidden_size, bias=False)

        # sigmoid(-5) ~= 0.0067: the state-feedback term starts nearly off.
        self.sab_gate = nn.Parameter(torch.tensor(-5.0))

        # Per-head normalization of the scan output, then output projection.
        self.group_norm = nn.GroupNorm(num_heads, hidden_size)
        self.W_o = nn.Linear(hidden_size, hidden_size, bias=False)

        # Small init keeps the decay and gate branches near their neutral
        # points at the start of training.
        nn.init.normal_(self.W_w.weight, std=0.01)
        nn.init.normal_(self.W_a.weight, std=0.01)
        nn.init.normal_(self.W_g.weight, std=0.02)

    def _token_shift(self, x):
        # Learned interpolation between each token and its left neighbor;
        # the first position mixes with zeros.
        x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

        def mix(mu):
            return x + (x_prev - x) * torch.sigmoid(mu)

        return {
            'r': mix(self.mu_r), 'w': mix(self.mu_w),
            'k': mix(self.mu_k), 'v': mix(self.mu_v),
            'a': mix(self.mu_a), 'g': mix(self.mu_g),
        }

    def _wkv7_scan_fla(self, r, w, k, v, a, sab_scale):
        B, T, H, D = r.shape
        orig_dtype = r.dtype
        r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
        k_scaled = k * (D ** -0.5)
        # 0.6065... = exp(-0.5), so the decay exp(log_w) lies in (~0.545, 1).
        w_log = -0.6065306597633104 * torch.sigmoid(w)
        a_sig = torch.sigmoid(a)
        # Map onto fla's rank-1 update S <- S*diag(w) + (S a) b^T + v k^T.
        # Note: this path has no per-step state clamp, unlike the Triton and
        # PyTorch backends, so outputs can differ slightly on long sequences.
        a_fla = -k_scaled
        b_fla = sab_scale * k_scaled * a_sig
        # k is pre-scaled by 1/sqrt(D), hence scale=1.0 here.
        o, _ = _fla_chunk_rwkv7(r, k_scaled, v, a_fla, b_fla, log_w=w_log,
                                scale=1.0, head_first=False)
        return o.to(orig_dtype)

    def _wkv7_scan_python(self, r, w, k, v, a, sab_scale):
        # Pure-PyTorch reference: slow but differentiable on any device.
        B, T, H, D = r.shape
        orig_dtype = r.dtype

        r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
        k = k * (D ** -0.5)
        decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
        a = torch.sigmoid(a)

        state = torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
        outputs = []

        for t in range(T):
            # Truncated BPTT: cut the graph every 16 steps to bound memory.
            if t > 0 and t % 16 == 0:
                state = state.detach()

            kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]

            sa = torch.einsum('bhij,bhj->bhi', state, -kt)
            sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
            state = state * dt.unsqueeze(-2) + sab_scale * sab + torch.einsum('bhi,bhj->bhij', vt, kt)
            state = state.clamp(-10.0, 10.0)

            outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))

        return torch.stack(outputs, dim=1).to(orig_dtype)

    def _wkv7_scan(self, r, w, k, v, a, sab_scale):
        # The Triton kernel is forward-only, so take it only when no gradient
        # is required; training falls through to a differentiable backend.
        if _TRITON_AVAILABLE and r.is_cuda and not torch.is_grad_enabled():
            B, T, H, D = r.shape
            orig_dtype = r.dtype
            r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
            k = k * (D ** -0.5)
            decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
            a = torch.sigmoid(a)
            return _wkv7_scan_triton(r, decay, k, v, a, sab_scale).to(orig_dtype)
        if _FLA_AVAILABLE and r.is_cuda:
            return self._wkv7_scan_fla(r, w, k, v, a, sab_scale)
        return self._wkv7_scan_python(r, w, k, v, a, sab_scale)
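
    @torch.no_grad()
    def _check_backend_parity(self, B=2, T=32, atol=3e-3):
        # Debugging sketch, not used by forward(): compares whichever fused
        # backend _wkv7_scan selects on this machine against the PyTorch
        # reference on random inputs. The tolerance is a rough guess; the
        # FLA path omits the state clamp, so small differences are expected.
        H, D = self.num_heads, self.head_size
        device = next(self.parameters()).device
        r, w, k, v, a = [torch.randn(B, T, H, D, device=device) for _ in range(5)]
        sab_scale = torch.sigmoid(self.sab_gate)
        fused = self._wkv7_scan(r, w, k, v, a, sab_scale)
        ref = self._wkv7_scan_python(r, w, k, v, a, sab_scale)
        return (fused - ref).abs().max().item() <= atol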

    def forward(self, x, attention_mask=None, **kwargs):
        B, T, C = x.shape
        H, D = self.num_heads, self.head_size

        mixed = self._token_shift(x)
        r = self.W_r(mixed['r']).view(B, T, H, D)
        w = self.W_w(mixed['w']).view(B, T, H, D)
        k = self.W_k(mixed['k']).view(B, T, H, D)
        v = self.W_v(mixed['v']).view(B, T, H, D)
        a = self.W_a(mixed['a']).view(B, T, H, D)
        g = torch.sigmoid(self.W_g(mixed['g']))

        sab_scale = torch.sigmoid(self.sab_gate)

        # Bidirectional mixing: run the causal scan forward and on the
        # time-reversed sequence, then average the two directions.
        out_fwd = self._wkv7_scan(r, w, k, v, a, sab_scale)
        out_bwd = self._wkv7_scan(
            r.flip(1), w.flip(1), k.flip(1), v.flip(1), a.flip(1), sab_scale
        ).flip(1)

        out = (out_fwd + out_bwd).reshape(B, T, C) * 0.5
        # GroupNorm expects (B, C, T); one group per head.
        out = self.group_norm(out.transpose(1, 2)).transpose(1, 2)
        out = self.W_o(out * g)

        # Second element mirrors the attention-module interface (no weights).
        return out, None


def init_from_attention(birwkv, attn_module):
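    """Copy Q/K/V/O weights from a softmax-attention module into a
    BiRWKV7Layer (Q->W_r, K->W_k, V->W_v, O->W_o) where the projections
    can be found. Returns the list of transfers that were applied.
    """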
    q_proj = k_proj = v_proj = o_proj = None

    if hasattr(attn_module, 'Wqkv'):
        # Fused QKV projection: rows are stacked as [Q; K; V].
        fused = attn_module.Wqkv.weight.data
        C = fused.shape[1]
        q_proj, k_proj, v_proj = fused[:C], fused[C:2*C], fused[2*C:]
    else:
        for name in ['q_proj', 'query', 'W_q', 'wq']:
            if hasattr(attn_module, name):
                q_proj = getattr(attn_module, name).weight.data
                break
        for name in ['k_proj', 'key', 'W_k', 'wk']:
            if hasattr(attn_module, name):
                k_proj = getattr(attn_module, name).weight.data
                break
        for name in ['v_proj', 'value', 'W_v', 'wv']:
            if hasattr(attn_module, name):
                v_proj = getattr(attn_module, name).weight.data
                break

    for name in ['Wo', 'out_proj', 'o_proj', 'dense', 'W_o', 'wo']:
        if hasattr(attn_module, name):
            o_proj = getattr(attn_module, name).weight.data
            break

    transferred = []
    for src, dst, label in [
        (q_proj, birwkv.W_r, 'Q->R'),
        (k_proj, birwkv.W_k, 'K->K'),
        (v_proj, birwkv.W_v, 'V->V'),
        (o_proj, birwkv.W_o, 'O->O'),
    ]:
        if src is not None:
            dst.weight.data.copy_(src)
            transferred.append(label)

    return transferred
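

# Minimal smoke test (a sketch, not part of the library API): runs the layer
# on random input via whichever backend is available on this machine and
# checks that forward and backward complete with finite values.
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = BiRWKV7Layer(hidden_size=64, num_heads=2)
    x = torch.randn(2, 12, 64, requires_grad=True)

    out, _ = layer(x)
    assert out.shape == x.shape, out.shape
    out.sum().backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()
    print("BiRWKV7Layer OK:", tuple(out.shape))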