# Source: maotao/fairseq/modules/sparse_multihead_attention.py
# (hosting-page residue: julse's picture; "Upload 551 files"; commit be611b4, verified)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from .multihead_attention import MultiheadAttention
class SparseMultiheadAttention(MultiheadAttention):
    """Sparse Multi-Headed Attention.

    "Generating Long Sequences with Sparse Transformers". Implements
    fixed factorized self attention, where l=stride and c=expressivity.
    A(1) includes all words in the stride window and A(2) takes a summary of c
    words from the end of each stride window.
    If is_bidirectional=False, we do not include any words past the current word,
    as in the paper.

    Args:
        embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv,
        add_zero_attn, self_attention, encoder_decoder_attention:
            forwarded unchanged (positionally, in this order) to
            ``MultiheadAttention.__init__``.
        stride: window length ``l``; must be positive.
        expressivity: summary length ``c``; must not exceed ``stride``.
        is_bidirectional: when False, a position never attends to indices
            greater than its own.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False, stride=32, expressivity=8, is_bidirectional=True):
        super().__init__(
            embed_dim, num_heads, kdim, vdim, dropout, bias, add_bias_kv,
            add_zero_attn, self_attention, encoder_decoder_attention
        )

        self.is_bidirectional = is_bidirectional
        self.stride = stride
        self.expressivity = expressivity
        # Kept as an assertion so the exception type seen by callers is unchanged.
        assert self.stride > 0 and self.stride >= self.expressivity, (
            "stride must be positive and at least as large as expressivity"
        )

    # Used for Ai(2) calculations - beginning of [l-c, l] range
    def compute_checkpoint(self, word_index):
        """Return the first index of the summary range of ``word_index``'s window.

        A non-zero index that is an exact multiple of ``stride`` is treated as
        the last position of the preceding window; every other index belongs
        to the window that starts at the multiple of ``stride`` at or below it.
        """
        if word_index % self.stride == 0 and word_index != 0:
            return word_index - self.expressivity
        # Integer floor division: exact, unlike math.floor(a / b) which goes
        # through a float.
        window_start = word_index // self.stride * self.stride
        return window_start + self.stride - self.expressivity

    # Computes Ai(2)
    def compute_subset_summaries(self, absolute_max):
        """Return the set of summary indices (A(2)) that are < ``absolute_max``.

        For each window the summary covers ``expressivity + 1`` consecutive
        indices starting at the window's checkpoint, clipped to
        ``absolute_max``.
        """
        subset_two = set()
        first_checkpoint = self.compute_checkpoint(0)
        # Consecutive checkpoints are exactly one stride apart. Stepping by the
        # stride directly (instead of re-deriving each checkpoint from the
        # previous one) yields the same indices whenever expressivity < stride
        # and also terminates when expressivity == stride, where the
        # re-derivation maps back to the same checkpoint forever.
        for checkpoint_index in range(first_checkpoint, absolute_max, self.stride):
            summary_end = min(checkpoint_index + self.expressivity + 1, absolute_max)
            subset_two.update(range(checkpoint_index, summary_end))
        return subset_two

    # Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf
    def compute_fixed_attention_subset(self, word_index, tgt_len):
        """Return the set of indices that position ``word_index`` may attend to.

        The result always contains subset 1 (the position's whole window). In
        the unidirectional case it additionally contains the summary indices
        up to and including ``word_index``; in the bidirectional case the
        summaries are identical for every position, so they are left for the
        caller to add once.
        """
        # +1s account for range function; [min, max) -> [min, max]
        absolute_max = tgt_len if self.is_bidirectional else word_index + 1

        # Subset 1 - whole window
        if word_index % self.stride == 0 and word_index != 0:
            # A non-zero multiple of the stride closes the preceding window.
            window_start = word_index - self.stride
            window_end = word_index
        else:
            window_end = (word_index + self.stride) // self.stride * self.stride
            window_start = max(0, window_end - self.stride)
        subset_one = set(range(window_start, min(absolute_max, window_end + 1)))

        # Subset 2 - summary per window
        # If bidirectional, subset 2 is the same for every index
        subset_two = set()
        if not self.is_bidirectional:
            subset_two = self.compute_subset_summaries(absolute_max)

        return subset_one.union(subset_two)

    # Compute sparse mask - if bidirectional, can pre-compute and store
    def buffered_sparse_mask(self, tensor, tgt_len, src_len):
        """Build the additive ``(tgt_len, src_len)`` sparse attention mask.

        Allowed positions hold 0 and all others hold ``-inf``. The mask is
        built as a float tensor and then converted with ``type_as(tensor)``.
        Despite the name, nothing is cached here: the mask is rebuilt on every
        call.

        Raises:
            AssertionError: if ``tgt_len`` is not greater than ``stride``.
        """
        assert tgt_len > self.stride, "tgt_len must be greater than stride"
        sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float('-inf'))

        # If bidirectional, subset 2 is the same for every index
        subset_summaries = set()
        if self.is_bidirectional:
            subset_summaries = self.compute_subset_summaries(tgt_len)

        for i in range(tgt_len):
            fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len)
            fixed_attention_subset = fixed_attention_subset.union(subset_summaries)
            included_word_indices = torch.LongTensor(list(fixed_attention_subset))
            # Open up (set to 0) the columns row i is allowed to attend to.
            # NOTE(review): indices are bounded by tgt_len, so this presumes
            # src_len >= tgt_len - confirm against callers.
            sparse_mask[i].index_fill_(0, included_word_indices, 0)
        return sparse_mask.type_as(tensor)

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        """Add the sparse mask to ``attn_weights`` in place and return it.

        The mask is expanded to ``(bsz * num_heads, tgt_len, src_len)`` before
        the addition, so ``attn_weights`` is expected to have that shape.

        Returns:
            The same ``attn_weights`` tensor object, after the in-place
            addition. Returning it keeps callers that rebind the result
            (``attn_weights = self.apply_sparse_mask(...)``) working, while
            callers relying only on the in-place update are unaffected.
        """
        sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
        sparse_mask = sparse_mask.unsqueeze(0).expand(bsz * self.num_heads, tgt_len, src_len)
        attn_weights += sparse_mask
        return attn_weights