# HARE / streaming.py
import torch
import torch.nn.functional as F
from birwkv7 import BiRWKV7Layer
def wkv7_forward_scan(r, w, k, v, a, sab_scale, init_state=None):
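    """Pure-PyTorch reference scan for the (Bi)RWKV-7 WKV recurrence.

    With k pre-scaled by D**-0.5, a := sigmoid(a), and per-channel decay
    d_t := exp(-c * sigmoid(w_t)) (c ~ exp(-0.5)), each head's DxD state
    evolves as

        S_t = S_{t-1} diag(d_t)
              + sab_scale * (S_{t-1} @ (-k_t)) (k_t * a_t)^T
              + v_t k_t^T          (then clamped to [-10, 10])
        y_t = S_t @ r_t

    Shapes: r, w, k, v, a are (B, T, H, D); returns outputs (B, T, H, D)
    and the detached final state (B, H, D, D).
    """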
B, T, H, D = r.shape
r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
k = k * (D ** -0.5)
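    # 0.6065306597633104 ~= exp(-0.5); with sigmoid(w) in (0, 1), the decay
    # below stays inside roughly (0.545, 1).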
decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w))
a = torch.sigmoid(a)
sab_s = float(sab_scale)
state = init_state.float().clone() if init_state is not None else \
torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)
outputs = []
for t in range(T):
kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]
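        # sa reads the running state along -k_t; sab re-injects that readout
        # along k_t * a_t, a gated rank-one correction on top of the decay.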
sa = torch.einsum('bhij,bhj->bhi', state, -kt)
sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
state = state * dt.unsqueeze(-2) + sab_s * sab + \
torch.einsum('bhi,bhj->bhij', vt, kt)
state = state.clamp(-10.0, 10.0)
outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))
return torch.stack(outputs, dim=1), state.detach()
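
# Example call (illustrative only; shapes are arbitrary):
#   B, T, H, D = 2, 16, 4, 32
#   r, w, k, v, a = (torch.randn(B, T, H, D) for _ in range(5))
#   out, state = wkv7_forward_scan(r, w, k, v, a, sab_scale=0.5)
#   # out: (B, T, H, D) per-step readouts; state: (B, H, D, D) final state
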
class SpanEncoder:
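    """Streaming span encoder built on BiRWKV7 layers.

    Temporarily monkey-patches each BiRWKV7Layer.forward so that the
    forward-scan WKV state and the last hidden input survive between calls,
    letting a long document be encoded incrementally via encode_span() and
    extend_right().
    """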
def __init__(self, model, tokenizer, device, chunk_size=512):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.chunk_size = chunk_size
self.birwkv_layers = []
self.birwkv_ids = {}
for m in model.modules():
if isinstance(m, BiRWKV7Layer):
self.birwkv_ids[id(m)] = len(self.birwkv_layers)
self.birwkv_layers.append(m)
self._originals = {}
self._hooked = False
self._active_states = [None] * len(self.birwkv_layers)
self.span_data = {}
def _hook(self):
if self._hooked:
return
for layer in self.birwkv_layers:
self._originals[id(layer)] = layer.forward
layer.forward = self._make_fwd(layer)
self._hooked = True
def _unhook(self):
if not self._hooked:
return
for layer in self.birwkv_layers:
layer.forward = self._originals[id(layer)]
self._originals.clear()
self._hooked = False
def _make_fwd(self, layer):
enc = self
idx = self.birwkv_ids[id(layer)]
def fwd(x, attention_mask=None, **kwargs):
B, T, C_ = x.shape
H, D = layer.num_heads, layer.head_size
prev = enc._active_states[idx]
if prev is not None:
x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
else:
x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))
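            # Token shift: blend each position with its predecessor (zero at
            # t=0, or last_x carried over from the previous span), gated per
            # channel through sigmoid(mu_*).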
def mix(mu):
return x + (x_prev - x) * torch.sigmoid(mu)
r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev is not None else None
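            # Bidirectional scan: the forward pass consumes and updates the
            # carried streaming state, while the backward pass always starts
            # fresh (its "past" is text that has not been seen yet). Prefer
            # the fused Triton kernel, falling back to the reference scan.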
try:
from birwkv7_triton import wkv7_scan_triton
r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
a_f = torch.sigmoid(a.float())
decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w.float()))
out_fwd, wkv_state = wkv7_scan_triton(
r_f, decay, k_f, v_f, a_f, sab_scale,
return_state=True, init_state=init_st)
out_bwd = wkv7_scan_triton(
r_f.flip(1), decay.flip(1), k_f.flip(1),
v_f.flip(1), a_f.flip(1), sab_scale,
return_state=False).flip(1)
            except Exception:  # Triton missing or kernel failed; use the PyTorch scan
out_fwd, wkv_state = wkv7_forward_scan(
r, w, k, v, a, sab_scale, init_st)
out_bwd = wkv7_forward_scan(
r.flip(1), w.flip(1), k.flip(1),
v.flip(1), a.flip(1), sab_scale, None)[0].flip(1)
enc._active_states[idx] = {
'wkv_state': wkv_state,
'last_x': x[:, -1:].detach().clone(),
}
out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
out = layer.W_o(out * g)
return out, None
return fwd
@torch.no_grad()
def _forward_encode_raw(self, text, init_states=None, max_length=8192):
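        """Run one hooked forward pass over `text`.

        If init_states is given, the per-layer streaming states are seeded
        from a clone of it. Returns the content-token hidden states, their
        count, and a snapshot of the final per-layer states so the caller
        can resume from them later.
        """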
self._hook()
if init_states is not None:
self._active_states = [
{k: v.clone() for k, v in s.items()} if s else None
for s in init_states
]
else:
self._active_states = [None] * len(self.birwkv_layers)
enc = self.tokenizer(text, return_tensors='pt', truncation=True,
max_length=max_length)
ids = enc['input_ids'].to(self.device)
mask = enc['attention_mask'].to(self.device)
h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
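        # Keep content tokens only: positions 0 and -1 are the tokenizer's
        # special tokens (e.g. [CLS]/[SEP]).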
content = h[0, 1:-1, :].cpu()
n_content = content.shape[0]
final_states = [
{k: v.clone() for k, v in s.items()} if s else None
for s in self._active_states
]
self._unhook()
return content, n_content, final_states
def _chunk_hidden(self, content, return_residual=False):
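        """Mean-pool hidden states into L2-normalised chunk embeddings.

        A trailing chunk shorter than 32 tokens is not embedded; with
        return_residual=True it is returned instead so extend_right() can
        prepend it to the next span. If no chunk qualifies at all, the whole
        span is pooled into a single embedding.
        """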
T = content.shape[0]
chunks = []
last_end = 0
for start in range(0, T, self.chunk_size):
end = min(start + self.chunk_size, T)
if end - start < 32:
break
emb = F.normalize(content[start:end].mean(0, keepdim=True),
p=2, dim=-1)
chunks.append(emb)
last_end = end
if not chunks and T > 0:
chunks.append(F.normalize(content.mean(0, keepdim=True),
p=2, dim=-1))
last_end = T
if return_residual:
residual = content[last_end:] if last_end < T else None
return chunks, residual
return chunks
@torch.no_grad()
def encode_query(self, query):
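        """Embed a query: masked mean pooling over valid tokens, then L2
        normalisation; returned on CPU with shape (1, hidden)."""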
        assert not self._hooked, 'encode_query must run without streaming hooks active'
enc = self.tokenizer(query, return_tensors='pt', truncation=True,
max_length=512)
ids = enc['input_ids'].to(self.device)
mask = enc['attention_mask'].to(self.device)
h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
m = mask.unsqueeze(-1).float()
emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
return F.normalize(emb, p=2, dim=-1).cpu()
def encode_span(self, text, key):
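        """Encode a fresh span and cache its states, chunk embeddings, token
        count, and residual tail under `key`; returns the token count."""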
content, n_tok, states = self._forward_encode_raw(text)
chunks, residual = self._chunk_hidden(content, return_residual=True)
self.span_data[key] = {
'layer_states': states,
'chunk_embs': chunks,
'n_tokens': n_tok,
'residual_hidden': residual,
}
return n_tok
def extend_right(self, piece_text, old_key, new_key):
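        """Append `piece_text` to the span stored under old_key.

        Encoding resumes from the cached per-layer states, the previous
        residual tail is stitched onto the new hidden states before
        chunking, and the result is re-stored under new_key.
        """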
old = self.span_data.pop(old_key)
content, n_new, states = self._forward_encode_raw(
piece_text, init_states=old['layer_states'])
if old.get('residual_hidden') is not None:
content = torch.cat([old['residual_hidden'], content], dim=0)
new_chunks, residual = self._chunk_hidden(
content, return_residual=True)
self.span_data[new_key] = {
'layer_states': states,
'chunk_embs': old['chunk_embs'] + new_chunks,
'n_tokens': old['n_tokens'] + n_new,
'residual_hidden': residual,
}
return n_new
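
# Minimal streaming sketch (hypothetical setup: `load_model_and_tokenizer`
# stands in for however the surrounding project builds its encoder and is
# not defined in this file):
#
#   model, tokenizer = load_model_and_tokenizer()
#   enc = SpanEncoder(model, tokenizer, device='cuda')
#   enc.encode_span(first_half, key='doc')
#   enc.extend_right(second_half, old_key='doc', new_key='doc')
#   q = enc.encode_query('what does the document say about X?')
#   chunk_embs = torch.cat(enc.span_data['doc']['chunk_embs'], dim=0)
#   scores = chunk_embs @ q.squeeze(0)  # cosine scores: both sides L2-normed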