|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
import torch |
|
|
from .multihead_attention import MultiheadAttention |
|
|
|
|
|
|
|
|
class SparseMultiheadAttention(MultiheadAttention):
    """Sparse Multi-Headed Attention.

    "Generating Long Sequences with Sparse Transformers". Implements
    fixed factorized self attention, where l=stride and c=expressivity.
    A(1) includes all words in the stride window and A(2) takes a summary of c
    words from the end of each stride window.
    If is_bidirectional=False, we do not include any words past the current word,
    as in the paper.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False, stride=32, expressivity=8, is_bidirectional=True):
        super().__init__(
            embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv,
            add_zero_attn, self_attention, encoder_decoder_attention
        )

        self.is_bidirectional = is_bidirectional
        self.stride = stride           # l: width of each attention window
        self.expressivity = expressivity  # c: summary positions per window
        # Explicit validation instead of `assert`, which is stripped under -O.
        if self.stride <= 0 or self.stride < self.expressivity:
            raise ValueError(
                'stride must be positive and >= expressivity '
                '(got stride={}, expressivity={})'.format(self.stride, self.expressivity)
            )

    def compute_checkpoint(self, word_index):
        """Return the first summary (A(2)) position of the stride window that
        contains ``word_index``, i.e. the start of the last ``expressivity``
        positions of that window.
        """
        if word_index % self.stride == 0 and word_index != 0:
            # Exactly on a window boundary: summary ends at word_index itself.
            checkpoint_index = word_index - self.expressivity
        else:
            # (word_index // stride) * stride is the start of the current window;
            # integer // is exact where math.floor(x / stride) relies on floats.
            checkpoint_index = (
                (word_index // self.stride) * self.stride
                + self.stride - self.expressivity
            )
        return checkpoint_index

    def compute_subset_summaries(self, absolute_max):
        """Return the set of all summary (A(2)) positions strictly below
        ``absolute_max``, collected window by window.
        """
        checkpoint_index = self.compute_checkpoint(0)
        subset_two = set()
        while checkpoint_index <= absolute_max - 1:
            # +1 accounts for range's half-open interval; clamp at absolute_max.
            summary = set(range(checkpoint_index, min(
                checkpoint_index + self.expressivity + 1, absolute_max)
            ))
            subset_two = subset_two.union(summary)
            # Jump to the summary of the next stride window.
            checkpoint_index = self.compute_checkpoint(checkpoint_index + self.stride)
        return subset_two

    def compute_fixed_attention_subset(self, word_index, tgt_len):
        """Return the set of key positions query ``word_index`` may attend to
        under the fixed factorized pattern (window A(1) plus, in the
        unidirectional case, the summaries A(2) of preceding windows).
        """
        # +1s account for range function; [min, max) -> [min, max]
        if not self.is_bidirectional:
            # Causal: never look past the current word.
            absolute_max = word_index + 1
        else:
            absolute_max = tgt_len

        # Subset 1 - whole stride window containing word_index.
        rounded_index = ((word_index + self.stride) // self.stride) * self.stride
        if word_index % self.stride == 0 and word_index != 0:
            subset_one = set(range(word_index - self.stride, min(absolute_max, word_index + 1)))
        else:
            subset_one = set(range(max(0, rounded_index - self.stride), min(
                absolute_max, rounded_index + 1))
            )

        # Subset 2 - summary positions. In the bidirectional case these are
        # identical for every query, so the caller precomputes them once
        # (see buffered_sparse_mask) rather than here.
        subset_two = set()
        if not self.is_bidirectional:
            subset_two = self.compute_subset_summaries(absolute_max)

        return subset_one.union(subset_two)

    def buffered_sparse_mask(self, tensor, tgt_len, src_len):
        """Build the (tgt_len, src_len) additive attention mask for the fixed
        sparse pattern: 0.0 at allowed positions, -inf everywhere else.
        The result is cast to the dtype/device of ``tensor``.
        """
        # Explicit validation instead of `assert`, which is stripped under -O.
        if tgt_len <= self.stride:
            raise ValueError(
                'tgt_len ({}) must exceed stride ({})'.format(tgt_len, self.stride)
            )
        # Start fully masked; allowed positions are zeroed below.
        sparse_mask = torch.full((tgt_len, src_len), float('-inf'))

        # If bidirectional, the summary positions are the same for every query
        # row, so compute them once outside the loop.
        subset_summaries = set()
        if self.is_bidirectional:
            subset_summaries = self.compute_subset_summaries(tgt_len)

        for i in range(tgt_len):
            fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len)
            fixed_attention_subset = fixed_attention_subset.union(subset_summaries)
            included_word_indices = torch.LongTensor(list(fixed_attention_subset))
            sparse_mask[i].index_fill_(0, included_word_indices, 0)
        return sparse_mask.type_as(tensor)

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        """Add the sparse pattern mask to ``attn_weights`` (in place) and
        return the masked weights.

        ``attn_weights`` is expected to be (bsz * num_heads, tgt_len, src_len);
        the 2-D mask is broadcast across the batch*heads dimension.
        """
        sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
        sparse_mask = sparse_mask.unsqueeze(0).expand(bsz * self.num_heads, tgt_len, src_len)
        attn_weights += sparse_mask
        # Return the tensor so callers that rebind the hook's result
        # (attn_weights = self.apply_sparse_mask(...)) get the weights, not None.
        return attn_weights