import torch
import torch.nn.functional as F

from birwkv7 import BiRWKV7Layer


def wkv7_forward_scan(r, w, k, v, a, sab_scale, init_state=None):
    """Pure-PyTorch reference scan for the WKV7 recurrence (one direction)."""
    B, T, H, D = r.shape
    r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
    k = k * (D ** -0.5)
    # 0.6065306597633104 == exp(-0.5); keeps the per-channel decay strictly below 1.
    decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
    a = torch.sigmoid(a)
    sab_s = float(sab_scale)
    if init_state is not None:
        state = init_state.float().clone()
    else:
        state = torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
    outputs = []
    for t in range(T):
        kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]
        # Read the state along -k, then re-inject it along k * a (the "sab" term).
        sa = torch.einsum('bhij,bhj->bhi', state, -kt)
        sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
        state = state * dt.unsqueeze(-2) + sab_s * sab + \
            torch.einsum('bhi,bhj->bhij', vt, kt)
        state = state.clamp(-10.0, 10.0)
        outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))
    return torch.stack(outputs, dim=1), state.detach()


class SpanEncoder:
    """Encodes text spans with a BiRWKV7 model, caching per-layer WKV states so
    a span can later be extended to the right without re-encoding its prefix."""

    def __init__(self, model, tokenizer, device, chunk_size=512):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.chunk_size = chunk_size
        # Collect BiRWKV7 layers in module order and remember each layer's index.
        self.birwkv_layers = []
        self.birwkv_ids = {}
        for m in model.modules():
            if isinstance(m, BiRWKV7Layer):
                self.birwkv_ids[id(m)] = len(self.birwkv_layers)
                self.birwkv_layers.append(m)
        self._originals = {}
        self._hooked = False
        self._active_states = [None] * len(self.birwkv_layers)
        self.span_data = {}

    def _hook(self):
        """Swap each BiRWKV7 layer's forward for a state-carrying variant."""
        if self._hooked:
            return
        for layer in self.birwkv_layers:
            self._originals[id(layer)] = layer.forward
            layer.forward = self._make_fwd(layer)
        self._hooked = True

    def _unhook(self):
        """Restore the original layer forwards."""
        if not self._hooked:
            return
        for layer in self.birwkv_layers:
            layer.forward = self._originals[id(layer)]
        self._originals.clear()
        self._hooked = False

    def _make_fwd(self, layer):
        enc = self
        idx = self.birwkv_ids[id(layer)]

        def fwd(x, attention_mask=None, **kwargs):
            B, T, C_ = x.shape
            H, D = layer.num_heads, layer.head_size
            prev = enc._active_states[idx]
            if prev is not None:
                # Continue the token-shift across pieces using the cached last token.
                x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
            else:
                # Shift right by one step, zero-padded at the front.
                x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

            def mix(mu):
                return x + (x_prev - x) * torch.sigmoid(mu)

            r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
            w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
            k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
            v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
            a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
            g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
            sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev is not None else None

            try:
                # Prefer the fused Triton kernel. Only the forward direction
                # carries state across pieces; the backward pass is per-piece.
                from birwkv7_triton import wkv7_scan_triton
                r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
                a_f = torch.sigmoid(a.float())
                decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w.float()))
                out_fwd, wkv_state = wkv7_scan_triton(
                    r_f, decay, k_f, v_f, a_f, sab_scale,
                    return_state=True, init_state=init_st)
                out_bwd = wkv7_scan_triton(
                    r_f.flip(1), decay.flip(1), k_f.flip(1), v_f.flip(1),
                    a_f.flip(1), sab_scale, return_state=False).flip(1)
            except Exception:
                # Triton unavailable or the kernel failed: fall back to the
                # pure-PyTorch reference scan above.
                out_fwd, wkv_state = wkv7_forward_scan(
                    r, w, k, v, a, sab_scale, init_st)
                out_bwd = wkv7_forward_scan(
                    r.flip(1), w.flip(1), k.flip(1), v.flip(1), a.flip(1),
                    sab_scale, None)[0].flip(1)

            enc._active_states[idx] = {
                'wkv_state': wkv_state,
                'last_x': x[:, -1:].detach().clone(),
            }
            out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
            out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
            out = layer.W_o(out * g)
            return out, None

        return fwd

    @torch.no_grad()
    def _forward_encode_raw(self, text, init_states=None, max_length=8192):
        """Encode `text` with hooked layers, optionally seeded with per-layer states."""
        self._hook()
        if init_states is not None:
            self._active_states = [
                {k: v.clone() for k, v in s.items()} if s else None
                for s in init_states
            ]
        else:
            self._active_states = [None] * len(self.birwkv_layers)
        enc = self.tokenizer(text, return_tensors='pt', truncation=True,
                             max_length=max_length)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        # Drop the special tokens at both ends (single, unpadded sequence).
        content = h[0, 1:-1, :].cpu()
        n_content = content.shape[0]
        final_states = [
            {k: v.clone() for k, v in s.items()} if s else None
            for s in self._active_states
        ]
        self._unhook()
        return content, n_content, final_states

    def _chunk_hidden(self, content, return_residual=False):
        """Mean-pool fixed-size chunks into unit-norm embeddings; a short tail
        (< 32 tokens) is kept back as residual for a later extension."""
        T = content.shape[0]
        chunks = []
        last_end = 0
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            if end - start < 32:
                break
            emb = F.normalize(content[start:end].mean(0, keepdim=True),
                              p=2, dim=-1)
            chunks.append(emb)
            last_end = end
        if not chunks and T > 0:
            # Content shorter than one chunk: pool everything.
            chunks.append(F.normalize(content.mean(0, keepdim=True), p=2, dim=-1))
            last_end = T
        if return_residual:
            residual = content[last_end:] if last_end < T else None
            return chunks, residual
        return chunks

    @torch.no_grad()
    def encode_query(self, query):
        """Masked mean-pooled, unit-norm query embedding (model must be unhooked)."""
        assert not self._hooked
        enc = self.tokenizer(query, return_tensors='pt', truncation=True,
                             max_length=512)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        m = mask.unsqueeze(-1).float()
        emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
        return F.normalize(emb, p=2, dim=-1).cpu()

    def encode_span(self, text, key):
        """Encode a fresh span; cache its states, chunk embeddings, and residual."""
        content, n_tok, states = self._forward_encode_raw(text)
        chunks, residual = self._chunk_hidden(content, return_residual=True)
        self.span_data[key] = {
            'layer_states': states,
            'chunk_embs': chunks,
            'n_tokens': n_tok,
            'residual_hidden': residual,
        }
        return n_tok

    def extend_right(self, piece_text, old_key, new_key):
        """Append `piece_text` to an existing span, resuming from its cached
        layer states instead of re-encoding the whole prefix."""
        old = self.span_data.pop(old_key)
        content, n_new, states = self._forward_encode_raw(
            piece_text, init_states=old['layer_states'])
        if old.get('residual_hidden') is not None:
            # Prepend the un-chunked tail of the old span before re-chunking.
            content = torch.cat([old['residual_hidden'], content], dim=0)
        new_chunks, residual = self._chunk_hidden(content, return_residual=True)
        self.span_data[new_key] = {
            'layer_states': states,
            'chunk_embs': old['chunk_embs'] + new_chunks,
            'n_tokens': old['n_tokens'] + n_new,
            'residual_hidden': residual,
        }
        return n_new
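

# Usage sketch (not part of the module above): one way SpanEncoder might be
# driven for incremental span indexing and query scoring. The checkpoint path,
# span keys, and the cosine-similarity scoring below are illustrative
# assumptions; substitute any HF-style encoder whose blocks contain
# BiRWKV7Layer modules and expose `.last_hidden_state`.
if __name__ == '__main__':
    from transformers import AutoModel, AutoTokenizer

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    name = 'path/to/birwkv7-encoder'  # hypothetical checkpoint
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name).to(device).eval()

    span_enc = SpanEncoder(model, tokenizer, device, chunk_size=512)

    # Encode an initial span, then grow it to the right without re-encoding the
    # prefix: extend_right resumes from the cached per-layer WKV states.
    span_enc.encode_span('First part of a long document...', key='doc0')
    span_enc.extend_right(' ...continuation appended later.',
                          old_key='doc0', new_key='doc0+')

    # Score a query against every chunk embedding of the extended span.
    q = span_enc.encode_query('what does the continuation discuss?')       # (1, C), unit-norm
    chunks = torch.cat(span_enc.span_data['doc0+']['chunk_embs'], dim=0)   # (n_chunks, C), unit-norm
    scores = chunks @ q.squeeze(0)                                         # cosine similarities
    print(scores.topk(min(3, scores.numel())))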