# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from .multihead_attention import MultiheadAttention


class SparseMultiheadAttention(MultiheadAttention):
    """Sparse Multi-Headed Attention.

    "Generating Long Sequences with Sparse Transformers". Implements
    fixed factorized self attention, where l=stride and c=expressivity.
    A(1) includes all words in the stride window and A(2) takes a summary of c
    words from the end of each stride window.
    If is_bidirectional=False, we do not include any words past the current word,
    as in the paper.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        stride=32,
        expressivity=8,
        is_bidirectional=True,
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim,
            vdim,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            self_attention,
            encoder_decoder_attention,
        )

        self.is_bidirectional = is_bidirectional
        self.stride = stride
        self.expressivity = expressivity
        assert self.stride > 0 and self.stride >= self.expressivity

    # Used for Ai(2) calculations - beginning of [l-c, l] range
    def compute_checkpoint(self, word_index):
        if word_index % self.stride == 0 and word_index != 0:
            checkpoint_index = word_index - self.expressivity
        else:
            checkpoint_index = (
                math.floor(word_index / self.stride) * self.stride
                + self.stride
                - self.expressivity
            )
        return checkpoint_index

    # Computes Ai(2)
    def compute_subset_summaries(self, absolute_max):
        checkpoint_index = self.compute_checkpoint(0)
        subset_two = set()
        while checkpoint_index <= absolute_max - 1:
            summary = set(
                range(
                    checkpoint_index,
                    min(checkpoint_index + self.expressivity + 1, absolute_max),
                )
            )
            subset_two = subset_two.union(summary)
            checkpoint_index = self.compute_checkpoint(checkpoint_index + self.stride)

        return subset_two

    # Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf
    def compute_fixed_attention_subset(self, word_index, tgt_len):
        # +1s account for range function; [min, max) -> [min, max]
        if not self.is_bidirectional:
            absolute_max = word_index + 1
        else:
            absolute_max = tgt_len

        # Subset 1 - whole window
        rounded_index = (
            math.floor((word_index + self.stride) / self.stride) * self.stride
        )
        if word_index % self.stride == 0 and word_index != 0:
            subset_one = set(
                range(word_index - self.stride, min(absolute_max, word_index + 1))
            )
        else:
            subset_one = set(
                range(
                    max(0, rounded_index - self.stride),
                    min(absolute_max, rounded_index + 1),
                )
            )

        # Subset 2 - summary per window
        # If bidirectional, subset 2 is the same for every index
        subset_two = set()
        if not self.is_bidirectional:
            subset_two = self.compute_subset_summaries(absolute_max)

        return subset_one.union(subset_two)

    # Compute sparse mask - if bidirectional, can pre-compute and store
    def buffered_sparse_mask(self, tensor, tgt_len, src_len):
        assert tgt_len > self.stride
        sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float("-inf"))

        # If bidirectional, subset 2 is the same for every index
        subset_summaries = set()
        if self.is_bidirectional:
            subset_summaries = self.compute_subset_summaries(tgt_len)

        for i in range(tgt_len):
            fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len)
            fixed_attention_subset = fixed_attention_subset.union(subset_summaries)
            included_word_indices = torch.LongTensor(list(fixed_attention_subset))
            sparse_mask[i].index_fill_(0, included_word_indices, 0)
        return sparse_mask.type_as(tensor)
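
    # Apply the fixed sparse pattern to the raw attention scores: expand the
    # (tgt_len, src_len) mask across all bsz * num_heads rows and add it in
    # place, pushing positions outside the pattern to -inf.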
    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
        sparse_mask = sparse_mask.unsqueeze(0).expand(
            bsz * self.num_heads, tgt_len, src_len
        )
        attn_weights += sparse_mask
        # The += above already updates attn_weights in place; returning it as
        # well keeps callers that reassign the result
        # (attn_weights = self.apply_sparse_mask(...)) working.
        return attn_weights
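
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): a quick way to inspect the fixed sparse pattern
# this class produces. The parameter values (embed_dim=64, num_heads=4,
# stride=8, expressivity=2, tgt_len=32) are arbitrary choices for the example,
# and the block assumes the file is run from within its package
# (e.g. `python -m <package>.sparse_multihead_attention`) so the relative
# import of MultiheadAttention above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Causal (unidirectional) sparse self-attention: each word only sees
    # itself and earlier positions, per is_bidirectional=False.
    attn = SparseMultiheadAttention(
        embed_dim=64,
        num_heads=4,
        self_attention=True,
        stride=8,
        expressivity=2,
        is_bidirectional=False,
    )

    # Build the (tgt_len, src_len) additive mask for a length-32 sequence.
    # Entries of 0 mark keys a query may attend to; -inf marks masked positions.
    reference = torch.zeros(1)  # only used by type_as() to pick dtype/device
    mask = attn.buffered_sparse_mask(reference, tgt_len=32, src_len=32)

    # Row i of the printed matrix shows which positions word i attends to.
    print((mask == 0).int())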