"""Bidirectional WKV7-style recurrent mixing layer with optional GPU back-ends.

The module defines:

* an optional forward-only Triton kernel for the recurrence,
* an optional binding to ``fla.ops.rwkv7.chunk_rwkv7``,
* ``BiRWKV7Layer``, which runs the recurrence over the sequence in both
  directions and averages the two results,
* ``init_from_attention``, which seeds a layer from an attention module's
  projection weights.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

# Multiplier applied to sigmoid(w) to obtain the per-channel log-decay, so the
# decay exp(_LOG_DECAY_SCALE * sigmoid(w)) stays within (exp(_LOG_DECAY_SCALE), 1).
# NOTE(review): the value is numerically close to -exp(-0.5); confirm the origin.
_LOG_DECAY_SCALE = -0.6065306597633104


def _prepare_scan_inputs(r, w, k, v, a):
    """Cast the scan inputs to float32 and apply the shared pre-processing.

    Args:
        r, w, k, v, a: tensors of shape ``(B, T, H, D)``.

    Returns:
        Tuple ``(r, decay, k, v, a)`` of float32 tensors where ``k`` is scaled
        by ``D ** -0.5``, ``decay = exp(_LOG_DECAY_SCALE * sigmoid(w))`` and
        ``a`` has been passed through a sigmoid.
    """
    head_dim = r.shape[-1]
    r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
    k = k * (head_dim ** -0.5)
    decay = torch.exp(_LOG_DECAY_SCALE * torch.sigmoid(w))
    a = torch.sigmoid(a)
    return r, decay, k, v, a


# --- Optional Triton back-end ------------------------------------------------
# Best effort: any failure (Triton missing, JIT decoration error, ...) leaves
# _TRITON_AVAILABLE False and the names below undefined; every use is guarded
# by that flag.
_TRITON_AVAILABLE = False
try:
    import triton
    import triton.language as tl

    @triton.jit
    def _wkv7_fwd_kernel(
        R, K, V, DECAY, A, O, STATE_OUT, STATE_IN,
        sab_scale,
        T,
        stride_b, stride_t, stride_h,
        H: tl.constexpr,
        D: tl.constexpr,
        BLOCK_D: tl.constexpr,
        RETURN_STATE: tl.constexpr,
        HAS_INIT_STATE: tl.constexpr,
    ):
        # Forward-only recurrence; one program per (batch, head) pair.
        # The running state is a BLOCK_D x BLOCK_D float32 tile indexed as
        # [value channel i, key channel j]; channels >= D are masked out.
        pid = tl.program_id(0)
        b_idx = pid // H
        h_idx = pid % H
        base = b_idx * stride_b + h_idx * stride_h

        di = tl.arange(0, BLOCK_D)
        dj = tl.arange(0, BLOCK_D)
        mask_i = di < D
        mask_j = dj < D

        if HAS_INIT_STATE:
            # Initial state is read as a dense (B, H, D, D) buffer.
            s_off = b_idx * (H * D * D) + h_idx * (D * D)
            state_ptrs = STATE_IN + s_off + di[:, None] * D + dj[None, :]
            state_mask = mask_i[:, None] & mask_j[None, :]
            state = tl.load(state_ptrs, mask=state_mask, other=0.0).to(tl.float32)
        else:
            state = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)

        for t in range(T):
            t_off = base + t * stride_t
            kt = tl.load(K + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
            vt = tl.load(V + t_off + di, mask=mask_i, other=0.0).to(tl.float32)
            rt = tl.load(R + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
            dt = tl.load(DECAY + t_off + dj, mask=mask_j, other=1.0).to(tl.float32)
            at = tl.load(A + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)

            # sa = state @ (-k); sab = outer(sa, k * a)
            sa = tl.sum(state * (-kt)[None, :], axis=1)
            ka = kt * at
            sab = sa[:, None] * ka[None, :]

            # Decay along the key axis, add the scaled correction term and the
            # new outer(v, k) write, then clamp the state to [-10, 10].
            state = state * dt[None, :] + sab_scale * sab + vt[:, None] * kt[None, :]
            state = tl.minimum(tl.maximum(state, -10.0), 10.0)

            # Read-out: out = state @ r
            out_t = tl.sum(state * rt[None, :], axis=1)
            tl.store(O + t_off + di, out_t, mask=mask_i)

        if RETURN_STATE:
            s_off = b_idx * (H * D * D) + h_idx * (D * D)
            state_ptrs = STATE_OUT + s_off + di[:, None] * D + dj[None, :]
            state_mask = mask_i[:, None] & mask_j[None, :]
            tl.store(state_ptrs, state, mask=state_mask)

    def _wkv7_scan_triton(r, decay, k, v, a, sab_scale):
        """Launch the forward kernel on already pre-processed ``(B, T, H, D)`` inputs.

        The output is written into a freshly allocated buffer and ``sab_scale``
        is converted to a Python float, so the result carries no autograd
        history. Callers must only use this when no gradient is required.
        """
        B, T, H, D = r.shape
        r, k, v, decay, a = [x.contiguous() for x in (r, k, v, decay, a)]
        o = torch.empty_like(r)
        # Element strides of a contiguous (B, T, H, D) tensor.
        stride_b, stride_t, stride_h = T * H * D, H * D, D
        BLOCK_D = triton.next_power_of_2(D)
        _wkv7_fwd_kernel[(B * H,)](
            r, k, v, decay, a, o,
            None, None,
            float(sab_scale),
            T,
            stride_b, stride_t, stride_h,
            H=H, D=D, BLOCK_D=BLOCK_D,
            RETURN_STATE=False,
            HAS_INIT_STATE=False,
        )
        return o

    if torch.cuda.is_available():
        _TRITON_AVAILABLE = True
except Exception:
    pass

# --- Optional flash-linear-attention back-end --------------------------------
# Best effort: any failure while importing or probing leaves _FLA_AVAILABLE
# False.
_FLA_AVAILABLE = False
try:
    import torch.distributed.tensor as _tdt

    # Fill in names that may be absent from torch.distributed.tensor before
    # importing fla.
    if not hasattr(_tdt, 'Replicate'):
        try:
            from torch.distributed._tensor import Replicate as _R, Shard as _S
            _tdt.Replicate = _R
            _tdt.Shard = _S
        except ImportError:
            pass
    if not hasattr(_tdt, 'Placement'):
        try:
            from torch.distributed._tensor.placement_types import Placement as _P
            _tdt.Placement = _P
        except ImportError:
            pass
    if not hasattr(_tdt, 'distribute_module'):
        _tdt.distribute_module = lambda *a, **kw: None

    from fla.ops.rwkv7 import chunk_rwkv7 as _fla_chunk_rwkv7

    if torch.cuda.is_available():
        # Smoke test: run a tiny forward + backward and accept the back-end
        # only if the resulting gradient contains no NaN.
        # NOTE(review): this probe passes six positional tensors, whereas
        # BiRWKV7Layer._wkv7_scan_fla passes five positional tensors plus a
        # ``log_w`` keyword. Confirm both match the installed chunk_rwkv7
        # signature.
        _test_r = torch.randn(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)
        _test_w = -torch.ones(1, 1, 2, 64, device='cuda', dtype=torch.bfloat16)
        _test_o, _ = _fla_chunk_rwkv7(_test_r, _test_w, _test_r, _test_r, _test_r, _test_r, head_first=False)
        _test_o.sum().backward()
        if not _test_r.grad.isnan().any():
            _FLA_AVAILABLE = True
        del _test_r, _test_w, _test_o
        torch.cuda.empty_cache()
    else:
        # Nothing to probe without CUDA; the FLA path is only taken for CUDA
        # tensors anyway.
        _FLA_AVAILABLE = True
except Exception:
    pass


class BiRWKV7Layer(nn.Module):
    """Bidirectional sequence-mixing layer built on a WKV7-style recurrence.

    The input ``(B, T, hidden_size)`` is token-shifted, projected to
    ``r, w, k, v, a`` and a gate ``g``, scanned left-to-right and
    right-to-left, averaged, group-normalised, gated and projected out.

    Args:
        hidden_size: channel width of the input and output.
        num_heads: number of heads; must divide ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads

        # Per-channel token-shift mixing logits, one per projection input.
        # They are passed through a sigmoid in _token_shift, so the zero init
        # gives an even blend of the current and previous token.
        self.mu_r = nn.Parameter(torch.zeros(hidden_size))
        self.mu_w = nn.Parameter(torch.zeros(hidden_size))
        self.mu_k = nn.Parameter(torch.zeros(hidden_size))
        self.mu_v = nn.Parameter(torch.zeros(hidden_size))
        self.mu_a = nn.Parameter(torch.zeros(hidden_size))
        self.mu_g = nn.Parameter(torch.zeros(hidden_size))

        self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_w = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_g = nn.Linear(hidden_size, hidden_size, bias=False)

        # Logit of the scale on the state-correction term; sigmoid(-5) is
        # about 0.0067, so the term starts out almost switched off.
        self.sab_gate = nn.Parameter(torch.tensor(-5.0))
        self.group_norm = nn.GroupNorm(num_heads, hidden_size)
        self.W_o = nn.Linear(hidden_size, hidden_size, bias=False)

        nn.init.normal_(self.W_w.weight, std=0.01)
        nn.init.normal_(self.W_a.weight, std=0.01)
        nn.init.normal_(self.W_g.weight, std=0.02)

    def _token_shift(self, x):
        """Blend each token with its predecessor, once per projection.

        Args:
            x: tensor of shape ``(B, T, C)``.

        Returns:
            Dict with keys ``'r', 'w', 'k', 'v', 'a', 'g'``; each value is
            ``x + (x_prev - x) * sigmoid(mu)`` where ``x_prev`` is ``x``
            shifted one step along dim 1 with zeros in the first position.
        """
        x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

        def mix(mu):
            return x + (x_prev - x) * torch.sigmoid(mu)

        return {
            'r': mix(self.mu_r),
            'w': mix(self.mu_w),
            'k': mix(self.mu_k),
            'v': mix(self.mu_v),
            'a': mix(self.mu_a),
            'g': mix(self.mu_g),
        }

    def _wkv7_scan_fla(self, r, w, k, v, a, sab_scale):
        """Run the scan through ``fla``'s ``chunk_rwkv7``.

        Inputs are ``(B, T, H, D)``; the result is cast back to ``r``'s dtype.
        Unlike the Triton and Python paths, no state clamping or periodic
        detach is applied here by this module.
        """
        B, T, H, D = r.shape
        orig_dtype = r.dtype
        r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
        k_scaled = k * (D ** -0.5)
        w_log = _LOG_DECAY_SCALE * torch.sigmoid(w)
        a_sig = torch.sigmoid(a)
        a_fla = -k_scaled
        b_fla = sab_scale * k_scaled * a_sig
        o, _ = _fla_chunk_rwkv7(r, k_scaled, v, a_fla, b_fla, log_w=w_log, scale=1.0, head_first=False)
        return o.to(orig_dtype)

    def _wkv7_scan_python(self, r, w, k, v, a, sab_scale):
        """Reference implementation of the scan in plain PyTorch.

        Args:
            r, w, k, v, a: tensors of shape ``(B, T, H, D)``.
            sab_scale: scalar (tensor or float) multiplying the correction term.

        Returns:
            Tensor of shape ``(B, T, H, D)`` in ``r``'s original dtype.
        """
        B, T, H, D = r.shape
        orig_dtype = r.dtype
        if T == 0:
            # torch.stack cannot handle an empty list; return an empty result.
            return r.new_empty((B, 0, H, D))

        r, decay, k, v, a = _prepare_scan_inputs(r, w, k, v, a)
        state = torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
        outputs = []
        for t in range(T):
            if t > 0 and t % 16 == 0:
                # Cut the autograd graph every 16 steps so back-propagation
                # through the state is truncated.
                state = state.detach()
            kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]
            sa = torch.einsum('bhij,bhj->bhi', state, -kt)
            sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
            state = state * dt.unsqueeze(-2) + sab_scale * sab + torch.einsum('bhi,bhj->bhij', vt, kt)
            state = state.clamp(-10.0, 10.0)
            outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))
        return torch.stack(outputs, dim=1).to(orig_dtype)

    def _wkv7_scan(self, r, w, k, v, a, sab_scale):
        """Dispatch the scan to the best back-end for the given inputs.

        Order of preference for CUDA tensors: Triton (only when no gradient is
        needed), then FLA, then the Python reference. Non-CUDA tensors always
        use the Python reference.
        """
        if r.is_cuda:
            # The Triton kernel is forward-only and its output is detached, so
            # using it while gradients are required would silently cut the
            # graph for every scan input and for sab_scale.
            needs_grad = torch.is_grad_enabled() and any(
                getattr(t, 'requires_grad', False) for t in (r, w, k, v, a, sab_scale)
            )
            if _TRITON_AVAILABLE and not needs_grad:
                orig_dtype = r.dtype
                r, decay, k, v, a = _prepare_scan_inputs(r, w, k, v, a)
                return _wkv7_scan_triton(r, decay, k, v, a, sab_scale).to(orig_dtype)
            if _FLA_AVAILABLE:
                return self._wkv7_scan_fla(r, w, k, v, a, sab_scale)
        return self._wkv7_scan_python(r, w, k, v, a, sab_scale)

    def forward(self, x, attention_mask=None, **kwargs):
        """Apply the bidirectional layer.

        Args:
            x: tensor of shape ``(B, T, hidden_size)``.
            attention_mask: accepted but not used by this method.
                NOTE(review): padded positions therefore take part in both
                scans; confirm that callers expect this.
            **kwargs: ignored.

        Returns:
            ``(out, None)`` where ``out`` has shape ``(B, T, hidden_size)``.
        """
        B, T, C = x.shape
        H, D = self.num_heads, self.head_size
        mixed = self._token_shift(x)

        r = self.W_r(mixed['r']).view(B, T, H, D)
        w = self.W_w(mixed['w']).view(B, T, H, D)
        k = self.W_k(mixed['k']).view(B, T, H, D)
        v = self.W_v(mixed['v']).view(B, T, H, D)
        a = self.W_a(mixed['a']).view(B, T, H, D)
        g = torch.sigmoid(self.W_g(mixed['g']))
        sab_scale = torch.sigmoid(self.sab_gate)

        out_fwd = self._wkv7_scan(r, w, k, v, a, sab_scale)
        # Right-to-left pass: reverse time, scan, reverse the result back.
        out_bwd = self._wkv7_scan(
            r.flip(1), w.flip(1), k.flip(1), v.flip(1), a.flip(1), sab_scale
        ).flip(1)

        out = (out_fwd + out_bwd).reshape(B, T, C) * 0.5
        # GroupNorm expects channels in dim 1.
        out = self.group_norm(out.transpose(1, 2)).transpose(1, 2)
        out = self.W_o(out * g)
        return out, None


def _first_weight(module, names):
    """Return ``weight.data`` of the first attribute in ``names`` present on ``module``, else None."""
    for name in names:
        if hasattr(module, name):
            return getattr(module, name).weight.data
    return None


def init_from_attention(birwkv, attn_module):
    """Copy attention projection weights into a ``BiRWKV7Layer``.

    Query, key, value and output weights are mapped to ``W_r``, ``W_k``,
    ``W_v`` and ``W_o``. A fused ``Wqkv`` weight is split into three row
    blocks of ``Wqkv.weight.shape[1]`` rows each; otherwise the separate
    projections are looked up under several common attribute names.

    A projection that is missing, or whose shape differs from the destination
    weight (for example a narrower key/value projection), is skipped rather
    than raising or being broadcast.

    Args:
        birwkv: destination layer exposing ``W_r``, ``W_k``, ``W_v``, ``W_o``.
        attn_module: source attention module.

    Returns:
        List of labels (``'Q->R'``, ``'K->K'``, ``'V->V'``, ``'O->O'``) for
        the projections that were actually copied.
    """
    q_proj = k_proj = v_proj = None
    if hasattr(attn_module, 'Wqkv'):
        fused = attn_module.Wqkv.weight.data
        C = fused.shape[1]
        q_proj, k_proj, v_proj = fused[:C], fused[C:2 * C], fused[2 * C:]
    else:
        q_proj = _first_weight(attn_module, ['q_proj', 'query', 'W_q', 'wq'])
        k_proj = _first_weight(attn_module, ['k_proj', 'key', 'W_k', 'wk'])
        v_proj = _first_weight(attn_module, ['v_proj', 'value', 'W_v', 'wv'])
    # The output projection is looked up for fused and unfused modules alike.
    o_proj = _first_weight(attn_module, ['Wo', 'out_proj', 'o_proj', 'dense', 'W_o', 'wo'])

    transferred = []
    for src, dst, label in [
        (q_proj, birwkv.W_r, 'Q->R'),
        (k_proj, birwkv.W_k, 'K->K'),
        (v_proj, birwkv.W_v, 'V->V'),
        (o_proj, birwkv.W_o, 'O->O'),
    ]:
        if src is None or src.shape != dst.weight.shape:
            continue
        dst.weight.data.copy_(src)
        transferred.append(label)
    return transferred