"""Implement noncausally and causally masked linear attention."""

import torch
from torch.nn import Module

def causal_linear(Q, K, V):
    """Compute the numerators of causally masked linear attention.

    For each position i the result is phi(q_i) applied to the prefix sum
    sum_{j<=i} phi(k_j) v_j^T, i.e. the unnormalised causal attention output.
    """
    # Outer products phi(k_j) v_j^T for every position; the cumulative sum
    # over the sequence dimension realises the lower triangular (causal) mask.
    KV = K.unsqueeze(-1) * V.unsqueeze(-2)
    KV = KV.cumsum(1)
    # Contract every query with its prefix sum of key-value outer products.
    V_new = (Q.unsqueeze(-1) * KV).sum(-2)
    return V_new
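
# A minimal quadratic reference for causal_linear, kept here only to make the
# equivalence explicit (a sketch, not used by the module): materialising the
# full attention matrix and zeroing it above the diagonal gives the same
# numerators that the cumulative sum above computes without an N x N matrix.
def causal_linear_reference(Q, K, V):
    # A[n, h, i, j] = phi(q_i) . phi(k_j), masked for j > i.
    A = torch.einsum("nlhe,nshe->nhls", Q, K)
    A = A * torch.tril(torch.ones(A.shape[-2], A.shape[-1], device=A.device))
    # Numerators sum_{j<=i} A[i, j] v_j, back in (N, L, H, M) layout.
    return torch.einsum("nhls,nshm->nlhm", A, V)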

class SharedLinearAttention(Module):
    """Implement linear attention, both causally and noncausally masked, using
    dot products of feature maps in O(N D^2) complexity.

    See fast_transformers.attention.linear_attention.LinearAttention for the
    general concept of replacing the softmax with feature maps. In addition to
    that, we also make use of the fact that causal masking is a triangular
    mask, which allows us to apply the masking and still compute the attention
    in O(N D^2) complexity.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
    """

    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(SharedLinearAttention, self).__init__()
        # Use the provided feature map, defaulting to elu(x) + 1 which keeps
        # the features strictly positive.
        self.feature_map = feature_map or (
            lambda x: torch.nn.functional.elu(x) + 1
        )
        self.eps = eps

    def _make_sizes_compatible(self, Q, K):
        """Either slice or pad K in case the sizes do not match between Q
        and K."""
        N, L, H, E = Q.shape
        _, S, _, _ = K.shape
        if L == S:
            return Q, K

        if L < S:
            return Q, K[:, :L, :, :]

        if L > S:
            return Q, torch.cat([K, K.new_zeros(N, L - S, H, E)], dim=1)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths, causal):
        # Map queries and keys through the feature map and zero out the
        # features of padded key positions.
        Q = self.feature_map(queries)
        K = self.feature_map(keys)
        K = K * key_lengths.float_matrix[:, :, None, None]

        if causal:
            # Causal masking is applied implicitly through the cumulative sums
            # below, so only a full lower triangular mask is supported.
            if not attn_mask.lower_triangular:
                raise RuntimeError("CausalLinearAttention only supports full "
                                   "lower triangular masks")

            # Make Q and K cover the same sequence length before the prefix
            # sums (slices or zero-pads K as needed).
            Q, K = self._make_sizes_compatible(Q, K)

            # Normalizer Z_i = 1 / (phi(q_i) . sum_{j<=i} phi(k_j) + eps)
            Z = 1 / (torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)

            # Unnormalized causal numerators, rescaled per position.
            V = causal_linear(Q, K, values)
            return V * Z[:, :, :, None]
        else:
            # The non-causal path supports only an all-ones attention mask.
            if not attn_mask.all_ones:
                raise RuntimeError("LinearAttention does not support arbitrary "
                                   "attention masks")

            # Aggregate the values weighted by the key feature maps:
            # KV[n, h, m, d] = sum_s K[n, s, h, d] * values[n, s, h, m]
            KV = torch.einsum("nshd,nshm->nhmd", K, values)

            # Normalizer Z_i = 1 / (phi(q_i) . sum_j phi(k_j) + eps)
            Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)

            # Numerator times normalizer gives the attention output.
            V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)
            return V.contiguous()


class BaseMask(object):
    @property
    def bool_matrix(self):
        """Return a bool (uint8) matrix with 1s to all places that should be
        kept."""
        raise NotImplementedError()

    @property
    def float_matrix(self):
        """Return the bool matrix as a float to be used as a multiplicative
        mask for non softmax attentions."""
        if not hasattr(self, "_float_matrix"):
            with torch.no_grad():
                self._float_matrix = self.bool_matrix.float()
        return self._float_matrix

    @property
    def lengths(self):
        """If the matrix is of the following form

            1 1 1 0 0 0 0
            1 0 0 0 0 0 0
            1 1 0 0 0 0 0

        then return it as a vector of integers

            3 1 2.
        """
        if not hasattr(self, "_lengths"):
            with torch.no_grad():
                lengths = self.bool_matrix.long().sum(dim=-1)
                # Verify that the mask really is a length mask, i.e. every row
                # is a contiguous run of 1s followed by 0s.
                m = self.bool_matrix.view(-1, self.shape[-1])
                for i, l in enumerate(lengths.view(-1)):
                    if not torch.all(m[i, :l]):
                        raise ValueError("The mask is not a length mask")
                self._lengths = lengths
        return self._lengths

    @property
    def shape(self):
        """Return the shape of the boolean mask."""
        return self.bool_matrix.shape

    @property
    def additive_matrix(self):
        """Return a float matrix to be added to an attention matrix before
        softmax."""
        if not hasattr(self, "_additive_matrix"):
            with torch.no_grad():
                self._additive_matrix = torch.log(self.bool_matrix.float())
        return self._additive_matrix

    @property
    def additive_matrix_finite(self):
        """Same as additive_matrix but with -1e24 instead of infinity."""
        if not hasattr(self, "_additive_matrix_finite"):
            with torch.no_grad():
                self._additive_matrix_finite = (
                    (~self.bool_matrix).float() * (-1e24)
                )
        return self._additive_matrix_finite

    @property
    def all_ones(self):
        """Return true if the mask is all ones."""
        if not hasattr(self, "_all_ones"):
            with torch.no_grad():
                self._all_ones = torch.all(self.bool_matrix)
        return self._all_ones

    @property
    def lower_triangular(self):
        """Return true if the attention is a triangular causal mask."""
        if not hasattr(self, "_lower_triangular"):
            self._lower_triangular = False
            with torch.no_grad():
                try:
                    lengths = self.lengths
                    if len(lengths.shape) == 1:
                        target = torch.arange(1, len(lengths) + 1,
                                              device=lengths.device)
                        self._lower_triangular = torch.all(lengths == target)
                except ValueError:
                    pass
        return self._lower_triangular


class FullMask(BaseMask):
    """Thin wrapper over a pytorch tensor that provides the BaseMask
    interface.

    The arguments can be given both by keyword arguments and positional
    arguments. To imitate function overloading, the constructor checks the
    type of the first argument and if it is a tensor it treats it as the mask.
    Otherwise, it assumes that it was the N argument.

    Arguments
    ---------
        mask: The mask as a PyTorch tensor.
        N: The rows of the all True mask to be created if the mask argument is
           not provided.
        M: The columns of the all True mask to be created if the mask argument
           is not provided. If N is given M defaults to N.
        device: The device to create the mask in (defaults to cpu)
    """

    def __init__(self, mask=None, N=None, M=None, device="cpu"):
        # mask is a tensor, so treat it as the boolean mask itself.
        if mask is not None and isinstance(mask, torch.Tensor):
            if mask.dtype != torch.bool:
                raise ValueError("FullMask expects the mask to be bool")
            with torch.no_grad():
                self._mask = mask.clone()
            return

        # mask is an integer, so it was actually meant as N; shift the
        # positional arguments accordingly.
        if mask is not None and M is None and isinstance(mask, int):
            M = N
            N = mask

        if N is not None:
            M = M or N
            with torch.no_grad():
                self._mask = torch.ones(N, M, dtype=torch.bool, device=device)
            self._all_ones = True
            return

        raise ValueError("Either mask or N should be provided")

    @property
    def bool_matrix(self):
        return self._mask
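
# Illustrative uses of FullMask (following the constructor above): FullMask(3)
# builds a 3x3 all-True mask, FullMask(2, 3) a 2x3 all-True mask, and
# FullMask(torch.ones(2, 3, dtype=torch.bool)) wraps an existing bool tensor.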


class LengthMask(BaseMask):
    """Provide a BaseMask interface for lengths. Mostly to be used with
    sequences of different lengths.

    Arguments
    ---------
        lengths: The lengths as a PyTorch long tensor
        max_len: The maximum length for the mask (defaults to lengths.max())
        device: The device to be used for creating the masks (defaults to
                lengths.device)
    """

    def __init__(self, lengths, max_len=None, device=None):
        self._device = device or lengths.device
        with torch.no_grad():
            self._lengths = lengths.clone().to(self._device)
        self._max_len = max_len or self._lengths.max()

        self._bool_matrix = None
        self._all_ones = torch.all(self._lengths == self._max_len).item()

    @property
    def bool_matrix(self):
        if self._bool_matrix is None:
            with torch.no_grad():
                indices = torch.arange(self._max_len, device=self._device)
                self._bool_matrix = (
                    indices.view(1, -1) < self._lengths.view(-1, 1)
                )
        return self._bool_matrix
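
# For illustration: LengthMask(torch.tensor([2, 3])).bool_matrix is
#   [[True, True, False],
#    [True, True, True]]
# since max_len defaults to lengths.max() == 3.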


class TriangularCausalMask(LengthMask):
    """A square matrix with everything masked out above the diagonal.

    Arguments
    ---------
        N: The size of the matrix
        device: The device to create the mask in (defaults to cpu)
    """

    def __init__(self, N, device="cpu"):
        lengths = torch.arange(1, N + 1, device=device)
        super(TriangularCausalMask, self).__init__(lengths, N, device)
        self._lower_triangular = True
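
# For illustration: TriangularCausalMask(3).bool_matrix is
#   [[True, False, False],
#    [True, True,  False],
#    [True, True,  True]]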


if __name__ == "__main__":
    from llmzen.models.bsrnn.modeling_fast_attention import (
        SharedLinearAttention as SLA1,
    )

    self_attn1 = SLA1(10)
    self_attn2 = SharedLinearAttention(10)

    # Non-causal check: both implementations should agree on random inputs.
    q, k, v = torch.rand(2, 100, 4, 10 * 3).chunk(3, -1)
    m1, m2, m3 = (
        FullMask(q.shape[1], k.shape[1], device=q.device),
        FullMask(q.shape[0], q.shape[1], device=q.device),
        FullMask(q.shape[0], q.shape[1], device=q.device),
    )
    out1 = self_attn1(q, k, v, m1, m2, m3, causal=False).view(q.shape[0], q.shape[1], -1)
    out2 = self_attn2(q, k, v, m1, m2, m3, causal=False).view(q.shape[0], q.shape[1], -1)
    print(out1.shape, out2.shape)
    print((out1 - out2).abs().sum())
    assert (out1 - out2).abs().sum() < 1e-3

    # Causal check with a triangular attention mask.
    q, k, v = torch.rand(2, 100, 4, 10 * 3).chunk(3, -1)
    m1, m2, m3 = (
        TriangularCausalMask(q.shape[1], device=q.device),
        FullMask(q.shape[0], q.shape[1], device=q.device),
        FullMask(q.shape[0], q.shape[1], device=q.device),
    )
    out1 = self_attn1(q, k, v, m1, m2, m3, causal=True).view(q.shape[0], q.shape[1], -1)
    out2 = self_attn2(q, k, v, m1, m2, m3, causal=True).view(q.shape[0], q.shape[1], -1)
    print(out1.shape, out2.shape)
    print((out1 - out2).abs().sum())
    assert (out1 - out2).abs().sum() < 1e-3
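
    # Extra sanity check (a sketch added for illustration, not one of the
    # original tests): compare the causal output against an explicit quadratic
    # computation of the same feature-map attention.
    def phi(x):
        return torch.nn.functional.elu(x) + 1

    A = torch.einsum("nlhe,nshe->nhls", phi(q), phi(k))
    A = A * torch.tril(torch.ones(q.shape[1], k.shape[1]))
    ref = torch.einsum("nhls,nshm->nlhm", A / (A.sum(-1, keepdim=True) + 1e-6), v)
    ref = ref.reshape(q.shape[0], q.shape[1], -1)
    print((out2 - ref).abs().max())
    assert (out2 - ref).abs().max() < 1e-3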