# haoxiangsnr's picture
# Add files using upload-large-folder tool
# 5e598cd verified
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Hangting Chen <465244243@qq.com>
#
"""Implement noncausally and causally masked linear attention."""
import torch
from torch.nn import Module
# from fast_transformers.causal_product import causal_dot_product
# def causal_linear(Q, K, V):
# # B T H C
# Q = Q.permute(0,2,1,3).contiguous()
# # B H T C
# K = K.permute(0,2,1,3).contiguous()
# V = V.permute(0,2,1,3).contiguous()
# V_new = causal_dot_product(Q, K, V)
# return V_new.permute(0,2,1,3).contiguous()
# def causal_linear(Q, K, V):
# # # B T H C
# # Q = Q.permute(0,2,1,3).contiguous()
# # # B H T C
# # K = K.permute(0,2,1,3).contiguous()
# # V = V.permute(0,2,1,3).contiguous()
# V_new = causal_dot_product(Q, K, V)
# return V_new.contiguous()
def causal_linear(Q, K, V):
    """Compute the unnormalized causal linear-attention product.

    For every time step ``t`` (independently per batch element and head) this
    returns ``sum_{s <= t} (Q[t] . K[s]) * V[s]``, computed as
    ``Q[t] @ cumsum_{s <= t} outer(K[s], V[s])``.

    Arguments
    ---------
    Q: tensor of shape (B, T, H, C), feature-mapped queries
    K: tensor of shape (B, T, H, C), feature-mapped keys
    V: tensor of shape (B, T, H, D), values

    Returns
    -------
    Tensor of shape (B, T, H, D).

    Note: the running key/value state is materialised as a (B, T, H, C, D)
    tensor, so memory grows with T * C * D.
    """
    # Running sum over time of the per-step outer products K[t] V[t]^T.
    # Shape: (B, T, H, C, D)
    KV = torch.einsum("nthc,nthd->nthcd", K, V).cumsum(1)
    # Contract each query with its accumulated state. Using einsum avoids the
    # second (B, T, H, C, D) temporary that a broadcast-multiply-then-sum
    # would allocate.
    return torch.einsum("nthc,nthcd->nthd", Q, KV)
class SharedLinearAttention(Module):
    """Linear attention (optionally causal) using dot products of feature maps.

    The softmax kernel is replaced by phi(Q) phi(K)^T, so the (L x S)
    attention matrix is never formed:

    * ``causal=False``: keys and values are summarised once as
      ``sum_s phi(K_s) V_s^T``, giving O(N D^2) complexity.
    * ``causal=True``: running (cumulative) sums over time realise a
      lower-triangular mask while staying linear in the sequence length.

    Arguments
    ---------
    query_dimensions: int, accepted for interface compatibility; not used.
    feature_map: callable applied to the last dimension of the queries and
        keys. When None (default), ``elu(x) + 1`` is used.
    eps: float, a small number added to the denominator for numerical
        stability (default: 1e-6)
    event_dispatcher: accepted for interface compatibility; not used.
    """

    def __init__(self, query_dimensions, feature_map=None, eps=1e-6, event_dispatcher=""):
        super(SharedLinearAttention, self).__init__()
        # Honour a caller-supplied feature map (it used to be silently
        # ignored); fall back to the documented elu(x) + 1 default.
        self.feature_map = feature_map if feature_map is not None else self._elu_feature_map
        self.eps = eps

    @staticmethod
    def _elu_feature_map(x):
        """Default positive feature map phi(x) = elu(x) + 1."""
        return torch.nn.functional.elu(x) + 1

    def _make_sizes_compatible(self, Q, K):
        """Either slice or pad K in case that the sizes do not match between Q
        and K.

        K is truncated to Q's length L, or zero-padded along dim 1 up to L.
        The padding uses Q's last dimension, so Q and K are assumed to share
        it. Note: ``forward`` does not call this helper.
        """
        N, L, H, E = Q.shape
        _, S, _, _ = K.shape
        if L == S:
            return Q, K
        if L < S:
            return Q, K[:, :L, :, :]
        if L > S:
            return Q, torch.cat([K, K.new_zeros(N, L - S, H, E)], dim=1)

    def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths, causal):
        """Apply linear attention.

        Arguments
        ---------
        queries: (N, L, H, E) tensor
        keys: (N, S, H, E) tensor; the causal path requires S == L
        values: (N, S, H, M) tensor
        attn_mask: mask object; must report ``lower_triangular`` when
            ``causal`` is true and ``all_ones`` otherwise.
        query_lengths: unused; kept for interface compatibility.
        key_lengths: mask object whose ``float_matrix`` of shape (N, S)
            multiplies the keys (0 drops a key).
        causal: bool, selects the causal or the full (non-causal) path.

        Returns
        -------
        (N, L, H, M) tensor.

        Raises
        ------
        RuntimeError: if ``attn_mask`` is not supported by the selected path.
        """
        # Apply the feature map to the queries and keys
        Q = self.feature_map(queries)
        K = self.feature_map(keys)
        # Zero the padded keys so they contribute nothing to the sums below.
        K = K * key_lengths.float_matrix[:, :, None, None]

        # TODO: Shall we divide the Q and K with a relatively large number to
        # avoid numerical instabilities in computing the denominator?
        # We used to divide each with the max norm of all q and k but
        # that seems relatively costly for a simple normalization.

        if causal:
            # The cumulative sums below implement exactly a lower-triangular
            # mask, so no other causal mask can be honoured.
            if not attn_mask.lower_triangular:
                raise RuntimeError(("CausalLinearAttention only supports full " "lower triangular masks"))
            # Normalizer 1 / (phi(Q_l) . sum_{s<=l} phi(K_s)); Q and K share
            # the time axis here, hence L == S is required.
            Z = 1 / (torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)
            # Unnormalized result sum_{s<=l} (phi(Q_l) . phi(K_s)) V_s
            V = causal_linear(Q, K, values)
            return V * Z[:, :, :, None]

        # Non-causal path: only an all-ones attention mask is supported.
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary " "attention masks"))
        # Summarise keys and values once: KV[n, h] = sum_s phi(K_s)^T V_s,
        # so the attention matrix is never materialised.
        KV = torch.einsum("nshd,nshm->nhmd", K, values)
        # Per-query normalizer 1 / (phi(Q_l) . sum_s phi(K_s))
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        # Finally compute and return the new values
        V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)
        return V.contiguous()
class BaseMask(object):
    """Interface shared by all attention masks.

    Subclasses implement ``bool_matrix``. Every other property is derived from
    it on first access and cached on the instance as a private attribute
    (``_float_matrix``, ``_lengths``, ``_additive_matrix``,
    ``_additive_matrix_finite``, ``_all_ones``, ``_lower_triangular``). A
    subclass may pre-populate any of these caches to skip the computation.
    """

    @property
    def bool_matrix(self):
        """Return a bool matrix with True in all places that should be kept."""
        raise NotImplementedError()

    @property
    def float_matrix(self):
        """Return the bool matrix as a float to be used as a multiplicative
        mask for non softmax attentions (1.0 keeps, 0.0 drops)."""
        if not hasattr(self, "_float_matrix"):
            with torch.no_grad():
                self._float_matrix = self.bool_matrix.float()
        return self._float_matrix

    @property
    def lengths(self):
        """If the matrix is of the following form
        1 1 1 0 0 0 0
        1 0 0 0 0 0 0
        1 1 0 0 0 0 0
        then return it as a vector of integers
        3 1 2.

        Raises
        ------
        ValueError: if some row is not a run of True followed only by False.
        """
        if not hasattr(self, "_lengths"):
            with torch.no_grad():
                mask = self.bool_matrix
                lengths = mask.long().sum(dim=-1)
                # ``lengths`` counts the True entries per row, so a row is a
                # valid length row iff it equals the prefix mask rebuilt from
                # its length. A single tensor comparison replaces a per-row
                # Python loop (and, unlike ``.view``, works for non-contiguous
                # and zero-width masks).
                positions = torch.arange(mask.shape[-1], device=mask.device)
                prefix = positions < lengths.unsqueeze(-1)
                if not torch.equal(mask.bool(), prefix):
                    raise ValueError("The mask is not a length mask")
                self._lengths = lengths
        return self._lengths

    @property
    def shape(self):
        """Return the shape of the boolean mask."""
        return self.bool_matrix.shape

    @property
    def additive_matrix(self):
        """Return a float matrix to be added to an attention matrix before
        softmax: 0 where kept, -inf where masked (log of the float mask)."""
        if not hasattr(self, "_additive_matrix"):
            with torch.no_grad():
                self._additive_matrix = torch.log(self.bool_matrix.float())
        return self._additive_matrix

    @property
    def additive_matrix_finite(self):
        """Same as additive_matrix but with -1e24 instead of infinity."""
        if not hasattr(self, "_additive_matrix_finite"):
            with torch.no_grad():
                self._additive_matrix_finite = (~self.bool_matrix).float() * (-1e24)
        return self._additive_matrix_finite

    @property
    def all_ones(self):
        """Return true if the mask is all ones."""
        if not hasattr(self, "_all_ones"):
            with torch.no_grad():
                self._all_ones = torch.all(self.bool_matrix)
        return self._all_ones

    @property
    def lower_triangular(self):
        """Return true if the attention is a triangular causal mask, i.e. a
        length mask whose 1-D lengths are exactly 1, 2, ..., N."""
        if not hasattr(self, "_lower_triangular"):
            self._lower_triangular = False
            with torch.no_grad():
                try:
                    lengths = self.lengths
                    if len(lengths.shape) == 1:
                        target = torch.arange(1, len(lengths) + 1, device=lengths.device)
                        self._lower_triangular = torch.all(lengths == target)
                except ValueError:
                    # Not a length mask, hence not lower triangular: keep False.
                    pass
        return self._lower_triangular
class FullMask(BaseMask):
    """Thin wrapper over a pytorch tensor that provides the BaseMask
    interface.

    The arguments can be given both by keyword arguments and positional
    arguments. To imitate function overloading, the constructor checks the
    type of the first argument: if it is a tensor it is treated as the mask,
    if it is an integer it is treated as the N argument.

    Arguments
    ---------
    mask: The mask as a PyTorch bool tensor (it is cloned).
    N: The rows of the all True mask to be created if the mask argument is
       not provided.
    M: The columns of the all True mask to be created if the mask argument
       is not provided. If N is given and M is omitted (None), M defaults
       to N; an explicit M=0 is respected.
    device: The device to create the mask in (defaults to cpu)

    Raises
    ------
    ValueError: if a tensor mask is not bool, or neither mask nor N is given.
    """

    def __init__(self, mask=None, N=None, M=None, device="cpu"):
        # mask is a tensor so we ignore N and M
        if mask is not None and isinstance(mask, torch.Tensor):
            if mask.dtype != torch.bool:
                raise ValueError("FullMask expects the mask to be bool")
            with torch.no_grad():
                self._mask = mask.clone()
            return

        # An integer first argument is the row count.
        if isinstance(mask, int):
            if M is None:
                # FullMask(N) or FullMask(N, M) passed positionally
                N, M = mask, N
            elif N is None:
                # FullMask(N, M=...)
                N = mask

        if N is not None:
            # Only default M when it was omitted; ``M or N`` would also
            # replace an explicit M=0.
            if M is None:
                M = N
            with torch.no_grad():
                self._mask = torch.ones(N, M, dtype=torch.bool, device=device)
            # Known by construction, so skip inspecting the tensor later.
            self._all_ones = True
            return

        raise ValueError("Either mask or N should be provided")

    @property
    def bool_matrix(self):
        return self._mask
class LengthMask(BaseMask):
    """Provide a BaseMask interface for lengths. Mostly to be used with
    sequences of different lengths.

    Row ``i`` of the mask keeps the first ``lengths[i]`` positions (clipped
    to ``max_len``).

    Arguments
    ---------
    lengths: The lengths as a PyTorch long tensor
    max_len: The maximum length for the mask (defaults to lengths.max()
             when None; an explicit 0 is respected)
    device: The device to be used for creating the masks (defaults to
            lengths.device)
    """

    def __init__(self, lengths, max_len=None, device=None):
        self._device = device or lengths.device
        with torch.no_grad():
            self._lengths = lengths.clone().to(self._device)
        # ``max_len or ...`` would discard an explicit max_len of 0.
        self._max_len = max_len if max_len is not None else self._lengths.max()
        # The boolean matrix is built lazily on first access.
        self._bool_matrix = None
        # A row is all True iff its length reaches max_len; ``>=`` keeps this
        # consistent with bool_matrix when a length exceeds max_len.
        self._all_ones = torch.all(self._lengths >= self._max_len).item()

    @property
    def bool_matrix(self):
        """Return the (len(lengths), max_len) prefix mask, built once."""
        if self._bool_matrix is None:
            with torch.no_grad():
                indices = torch.arange(self._max_len, device=self._device)
                self._bool_matrix = indices.view(1, -1) < self._lengths.view(-1, 1)
        return self._bool_matrix
class TriangularCausalMask(LengthMask):
    """Square N x N mask that hides every entry above the main diagonal.

    Implemented as a length mask whose row ``i`` has length ``i + 1``, so
    row ``i`` keeps positions ``0..i``.

    Arguments
    ---------
    N: The size of the matrix
    device: The device to create the mask in (defaults to cpu)
    """

    def __init__(self, N, device="cpu"):
        row_lengths = torch.arange(1, N + 1, device=device)
        super().__init__(row_lengths, max_len=N, device=device)
        # Lower triangular by construction; pre-fill the BaseMask cache so
        # the property never has to re-derive it.
        self._lower_triangular = True
if __name__ == "__main__":
    # Numerical parity check of this module against the reference
    # SharedLinearAttention in llmzen, for both the non-causal and causal
    # paths.
    from llmzen.models.bsrnn.modeling_fast_attention import (
        SharedLinearAttention as SLA1,
    )

    # Seed so that a failure is reproducible.
    torch.manual_seed(0)
    self_attn1 = SLA1(10)
    self_attn2 = SharedLinearAttention(10)

    for causal in (False, True):
        q, k, v = torch.rand(2, 100, 4, 10 * 3).chunk(3, -1)  # B T H C
        B, T = q.shape[0], q.shape[1]
        if causal:
            attn_mask = TriangularCausalMask(T, device=q.device)
        else:
            attn_mask = FullMask(T, k.shape[1], device=q.device)
        query_lengths = FullMask(B, T, device=q.device)
        key_lengths = FullMask(B, T, device=q.device)

        out1 = self_attn1(q, k, v, attn_mask, query_lengths, key_lengths, causal=causal).view(B, T, -1)
        out2 = self_attn2(q, k, v, attn_mask, query_lengths, key_lengths, causal=causal).view(B, T, -1)
        print(out1.shape, out2.shape)
        diff = (out1 - out2).abs()
        print(diff.sum())
        # Compare element-wise. A summed absolute difference over all 8000
        # outputs can approach a fixed threshold from float32 rounding alone,
        # while a per-element tolerance still catches real discrepancies.
        assert torch.allclose(out1, out2, rtol=1e-5, atol=1e-5), diff.max()