import torch
import torch.nn.functional as F

from birwkv7 import BiRWKV7Layer
def wkv7_forward_scan(r, w, k, v, a, sab_scale, init_state=None):
    """Pure-PyTorch reference scan for the WKV7 recurrence.

    r, w, k, v, a have shape (B, T, H, D). Returns the per-token outputs
    (B, T, H, D) and the detached final state (B, H, D, D).
    """
    B, T, H, D = r.shape
    r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
    k = k * (D ** -0.5)
    # Per-step decay exp(-c * sigmoid(w)) with c ~ exp(-0.5), so it stays in (~0.545, 1).
    decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
    a = torch.sigmoid(a)
    sab_s = float(sab_scale)
    state = init_state.float().clone() if init_state is not None else \
        torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
    outputs = []
    for t in range(T):
        kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]
        # Delta-rule style update: subtract the state's current response to kt
        # (gated by at and sab_scale), then add the new (vt, kt) association.
        sa = torch.einsum('bhij,bhj->bhi', state, -kt)
        sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
        state = state * dt.unsqueeze(-2) + sab_s * sab + \
            torch.einsum('bhi,bhj->bhij', vt, kt)
        state = state.clamp(-10.0, 10.0)
        outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))
    return torch.stack(outputs, dim=1), state.detach()
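# Minimal usage sketch for the reference scan above. The shapes and the
# sab_scale value are illustrative assumptions, not values used elsewhere in
# this module:
#
#     B, T, H, D = 2, 16, 4, 64
#     r, w, k, v, a = (torch.randn(B, T, H, D) for _ in range(5))
#     out, state = wkv7_forward_scan(r, w, k, v, a, sab_scale=0.5)
#     # out: (B, T, H, D) per-token readouts; state: (B, H, D, D) final state

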
class SpanEncoder:
    """Encodes long text spans with a BiRWKV7-based encoder, caching each
    layer's recurrent state so a span can later be extended to the right
    without re-encoding its prefix."""

    def __init__(self, model, tokenizer, device, chunk_size=512):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.chunk_size = chunk_size

        # Collect the model's BiRWKV7 layers and assign each a stable index.
        self.birwkv_layers = []
        self.birwkv_ids = {}
        for m in model.modules():
            if isinstance(m, BiRWKV7Layer):
                self.birwkv_ids[id(m)] = len(self.birwkv_layers)
                self.birwkv_layers.append(m)

        self._originals = {}
        self._hooked = False
        self._active_states = [None] * len(self.birwkv_layers)
        self.span_data = {}

    def _hook(self):
        """Swap each BiRWKV7 layer's forward for a state-carrying variant."""
        if self._hooked:
            return
        for layer in self.birwkv_layers:
            self._originals[id(layer)] = layer.forward
            layer.forward = self._make_fwd(layer)
        self._hooked = True

    def _unhook(self):
        """Restore the original forward methods."""
        if not self._hooked:
            return
        for layer in self.birwkv_layers:
            layer.forward = self._originals[id(layer)]
        self._originals.clear()
        self._hooked = False

    def _make_fwd(self, layer):
        enc = self
        idx = self.birwkv_ids[id(layer)]

        def fwd(x, attention_mask=None, **kwargs):
            B, T, C_ = x.shape
            H, D = layer.num_heads, layer.head_size
            prev = enc._active_states[idx]
            # Token-shift: mix each position with the previous token. When
            # continuing a span, the previous chunk's last token replaces the
            # zero pad at position 0.
            if prev is not None:
                x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
            else:
                x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

            def mix(mu):
                return x + (x_prev - x) * torch.sigmoid(mu)

            r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
            w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
            k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
            v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
            a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
            g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
            sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev is not None else None

            try:
                # Prefer the fused Triton kernel when it is available.
                from birwkv7_triton import wkv7_scan_triton
                r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
                a_f = torch.sigmoid(a.float())
                decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w.float()))
                out_fwd, wkv_state = wkv7_scan_triton(
                    r_f, decay, k_f, v_f, a_f, sab_scale,
                    return_state=True, init_state=init_st)
                out_bwd = wkv7_scan_triton(
                    r_f.flip(1), decay.flip(1), k_f.flip(1),
                    v_f.flip(1), a_f.flip(1), sab_scale,
                    return_state=False).flip(1)
            except Exception:
                # Fall back to the pure-PyTorch reference scan if the Triton
                # kernel is missing or fails.
                out_fwd, wkv_state = wkv7_forward_scan(
                    r, w, k, v, a, sab_scale, init_st)
                out_bwd = wkv7_forward_scan(
                    r.flip(1), w.flip(1), k.flip(1),
                    v.flip(1), a.flip(1), sab_scale, None)[0].flip(1)

            # Persist the forward-direction state and the last token so the
            # next chunk of this span can continue the recurrence.
            enc._active_states[idx] = {
                'wkv_state': wkv_state,
                'last_x': x[:, -1:].detach().clone(),
            }
            # Bidirectional output: average the forward and time-reversed scans.
            out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
            out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
            out = layer.W_o(out * g)
            return out, None

        return fwd

    @torch.no_grad()
    def _forward_encode_raw(self, text, init_states=None, max_length=8192):
        """Run the hooked encoder over `text`, optionally continuing from the
        per-layer states of a previous chunk. Returns the content-token
        hidden states, their count, and the final per-layer states."""
        self._hook()
        if init_states is not None:
            self._active_states = [
                {k: v.clone() for k, v in s.items()} if s else None
                for s in init_states
            ]
        else:
            self._active_states = [None] * len(self.birwkv_layers)

        enc = self.tokenizer(text, return_tensors='pt', truncation=True,
                             max_length=max_length)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)

        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        # Drop the leading and trailing special tokens (e.g. [CLS]/[SEP]).
        content = h[0, 1:-1, :].cpu()
        n_content = content.shape[0]

        final_states = [
            {k: v.clone() for k, v in s.items()} if s else None
            for s in self._active_states
        ]
        self._unhook()
        return content, n_content, final_states

    def _chunk_hidden(self, content, return_residual=False):
        """Mean-pool content hidden states into L2-normalized chunk
        embeddings of up to `chunk_size` tokens. A trailing piece shorter
        than 32 tokens is not pooled; with return_residual=True it is
        handed back as leftover hidden states."""
        T = content.shape[0]
        chunks = []
        last_end = 0
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            if end - start < 32:
                break
            emb = F.normalize(content[start:end].mean(0, keepdim=True),
                              p=2, dim=-1)
            chunks.append(emb)
            last_end = end
        if not chunks and T > 0:
            # Very short spans still get a single chunk embedding.
            chunks.append(F.normalize(content.mean(0, keepdim=True),
                                      p=2, dim=-1))
            last_end = T
        if return_residual:
            residual = content[last_end:] if last_end < T else None
            return chunks, residual
        return chunks

    @torch.no_grad()
    def encode_query(self, query):
        """Encode a short query with the unhooked model: mask-aware mean
        pooling followed by L2 normalization."""
        assert not self._hooked
        enc = self.tokenizer(query, return_tensors='pt', truncation=True,
                             max_length=512)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        m = mask.unsqueeze(-1).float()
        emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
        return F.normalize(emb, p=2, dim=-1).cpu()

    def encode_span(self, text, key):
        """Encode `text` as a new span stored under `key`; returns the number
        of content tokens."""
        content, n_tok, states = self._forward_encode_raw(text)
        chunks, residual = self._chunk_hidden(content, return_residual=True)
        self.span_data[key] = {
            'layer_states': states,
            'chunk_embs': chunks,
            'n_tokens': n_tok,
            'residual_hidden': residual,
        }
        return n_tok

    def extend_right(self, piece_text, old_key, new_key):
        """Extend the span stored under `old_key` with `piece_text`, reusing
        its cached layer states, and store the result under `new_key`."""
        old = self.span_data.pop(old_key)
        content, n_new, states = self._forward_encode_raw(
            piece_text, init_states=old['layer_states'])
        # Prepend any un-pooled tail from the previous piece so chunk
        # boundaries stay aligned to chunk_size.
        if old.get('residual_hidden') is not None:
            content = torch.cat([old['residual_hidden'], content], dim=0)
        new_chunks, residual = self._chunk_hidden(
            content, return_residual=True)
        self.span_data[new_key] = {
            'layer_states': states,
            'chunk_embs': old['chunk_embs'] + new_chunks,
            'n_tokens': old['n_tokens'] + n_new,
            'residual_hidden': residual,
        }
        return n_new
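
# Example end-to-end usage (a sketch, not part of this module: the checkpoint
# path, device, and span keys are placeholders, and the loaded model is
# assumed to contain BiRWKV7Layer modules so that SpanEncoder can hook them):
#
#     from transformers import AutoModel, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained('path/to/birwkv7-encoder')
#     model = AutoModel.from_pretrained(
#         'path/to/birwkv7-encoder', trust_remote_code=True).to('cuda').eval()
#     enc = SpanEncoder(model, tokenizer, device='cuda', chunk_size=512)
#     enc.encode_span(doc_text, key='doc-0')          # encode an initial span
#     enc.extend_right(more_text, 'doc-0', 'doc-0x')  # extend it to the right
#     q = enc.encode_query('what does the document say about X?')
#     sims = torch.cat(enc.span_data['doc-0x']['chunk_embs']) @ q.T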