from typing import Optional, Tuple

import torch
from einops import repeat
def naive_recurrent_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: Optional[bool] = False
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    dtype = q.dtype

    # number of query heads sharing each key/value head (GQA-style grouping)
    NG = q.shape[1] // k.shape[1]
    # when no gates are given, derive them from a cumulative softmax over the
    # slot scores: z is the running log-normalizer, g the per-step log-decay,
    # and s the normalized slot weights
    if g is None:
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z)
    q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
    # replicate the shared key/value heads so each query head has its own copy
    k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g))
    if initial_state is not None:
        initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state))

    # batch size, num heads, seq len, key dim, value dim, num memory slots
    B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
    # hk: recurrent key-to-slot state; ok: per-step (unnormalized) slot scores
    hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
    ok = torch.zeros_like(s)

    if scale is None:
        scale = q.shape[-1] ** -0.5

    final_state = None
    if initial_state is not None:
        hk += initial_state[0]

    # first pass: hk_t = hk_{t-1} * diag(exp(g_t)) + k_t s_t^T, then read out
    # the slot scores for each scaled query
    for i in range(T):
        q_i = q[:, :, i] * scale
        k_i = k[:, :, i]
        v_i = s[:, :, i]
        g_i = g[:, :, i].exp()
        hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
        ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
    # attention weights over the M memory slots
    qv = ok.softmax(-1)
    # hv: recurrent slot-to-value state
    hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
    ov = torch.zeros_like(v)
    if initial_state is not None:
        hv += initial_state[1]

    # second pass: hv_t = hv_{t-1} * diag(exp(g_t)) + s_t v_t^T, then read out
    # the output as a slot-weighted mixture of values
    for i in range(T):
        q_i = qv[:, :, i]
        k_i = s[:, :, i]
        v_i = v[:, :, i]
        g_i = g[:, :, i].exp()
        hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
        ov[:, :, i] = (q_i[..., None] * hv).sum(-2)

    if output_final_state:
        # undo the head replication: keep one state copy per key/value head
        final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
    return ov.to(dtype), final_state
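
# A minimal usage sketch (the shapes below are illustrative assumptions; `s`
# scores the M memory slots per step, in the `b h t d` layout used above):
#
#   B, H, T, K, V, M = 2, 4, 16, 32, 64, 8
#   q = torch.randn(B, H, T, K)
#   k = torch.randn(B, H, T, K)
#   v = torch.randn(B, H, T, V)
#   s = torch.randn(B, H, T, M)
#   o, state = naive_recurrent_abc(q, k, v, s, output_final_state=True)
#   # o: (B, H, T, V); state: ((B, H, K, M), (B, H, M, V))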


def naive_cumsum_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor
) -> Tuple[torch.Tensor, None]:
    """
    A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
    This is for demonstration purposes only, with no numerical stability guarantees.
    """
    dtype = q.dtype
    q, k, v, s = map(lambda x: x.float(), (q, k, v, s))

    scale = q.shape[-1] ** -0.5
    # exponentiate the slot scores, shifted by their max over time for
    # stability, and accumulate the running softmax normalizer z
    s = (s - s.max(2, True)[0]).exp()
    z = s.cumsum(2)
    # K, V: cumulative softmax-weighted averages of keys and values per slot
    K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
    V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
    # attend over the M slots, then mix the slot values into the output
    p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1)
    o = torch.einsum('...m,...md->...d', p, V)
    return o.to(dtype), None
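

if __name__ == '__main__':
    # Minimal consistency sketch (shapes here are illustrative assumptions):
    # with the default gates, the recurrent and cumsum variants compute the
    # same quantity and should agree up to floating-point error.
    B, H, T, D, M = 2, 4, 64, 32, 16
    q = torch.randn(B, H, T, D)
    k = torch.randn(B, H, T, D)
    v = torch.randn(B, H, T, D)
    s = torch.randn(B, H, T, M)
    o_rec, _ = naive_recurrent_abc(q, k, v, s)
    o_cum, _ = naive_cumsum_abc(q, k, v, s)
    print('max abs diff:', (o_rec - o_cum).abs().max().item())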