# PyTorch · ssl-aasist · custom_code
# Uploaded by ash56 using the upload-large-folder tool (commit 6789f6f, verified)
# Original file size: 1.99 kB
import math
import torch
def get_alibi(
    max_positions: int,
    attention_heads: int,
):
    """Build an ALiBi position-bias tensor of shape (attention_heads, T, T).

    The bias is symmetric — 0 on the diagonal and linearly decreasing with
    distance (-|i - j| scaled per head) — because wav2vec2 is a
    non-autoregressive model, so no causal masking is wanted.

    Args:
        max_positions: sequence length T of the bias matrix.
        attention_heads: number of heads; each head gets its own slope.

    Returns:
        A float tensor of shape (attention_heads, max_positions, max_positions).
    """

    def get_slopes(n):
        def slopes_for_power_of_2(count):
            first = 2 ** (-(2 ** -(math.log2(count) - 3)))
            step = first
            return [first * step ** i for i in range(count)]

        # The closed-form geometric sequence from the ALiBi paper only has
        # the desired properties when n is a power of two.  For other head
        # counts, take the slopes for the nearest lower power of two and
        # fill the remainder by interleaving the sequence for twice that
        # size (every other element).
        if math.log2(n).is_integer():
            return slopes_for_power_of_2(n)
        nearest_pow2 = 2 ** math.floor(math.log2(n))
        extras = get_slopes(2 * nearest_pow2)[0::2][: n - nearest_pow2]
        return slopes_for_power_of_2(nearest_pow2) + extras

    slopes = torch.Tensor(get_slopes(attention_heads))
    positions = torch.arange(max_positions)
    # Negative absolute distance: 0 on the diagonal, -|i - j| off it.
    distance_bias = -torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))
    # Scale the shared distance matrix by each head's slope.
    return slopes.unsqueeze(1).unsqueeze(1) * distance_bias.unsqueeze(0).expand(
        attention_heads, -1, -1
    )
def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
    """Restrict an ALiBi bias to the time steps selected by a mask.

    Selects the masked-in query (row) positions and then the matching key
    (column) positions, yielding a square bias over only the kept steps.
    NOTE(review): assumes `mask_indices` is a boolean tensor of shape
    (orig_B, orig_T) with the same number of True entries per batch row, so
    the intermediate views line up — confirm against the caller.

    Args:
        alibi_bias: bias viewable as (orig_B, heads, orig_T, orig_T).
        mask_indices: boolean keep-mask over time steps, per batch element.
        orig_B: original batch size.
        orig_T: original (unmasked) sequence length.

    Returns:
        Tensor of shape (orig_B * heads, M, M) where M is the number of
        kept time steps per batch element.
    """
    bias = alibi_bias.view(orig_B, -1, orig_T, orig_T)
    num_heads = bias.size(1)
    keep = mask_indices.unsqueeze(1)  # (B, 1, T): broadcast over heads
    # First pass: keep only the masked-in rows (query positions).
    bias = bias.masked_select(keep.unsqueeze(-1)).view(orig_B, num_heads, -1, orig_T)
    kept = bias.size(-2)
    # Second pass: keep the matching columns and flatten batch x heads.
    return bias.masked_select(keep.unsqueeze(-2)).view(-1, kept, kept)