|
|
import math |
|
|
import torch |
|
|
|
|
|
def get_alibi(
    max_positions: int,
    attention_heads: int,
):
    """Build a symmetric, per-head linear distance bias for attention.

    Each head ``h`` gets a positive slope ``s_h``; the bias between positions
    ``i`` and ``j`` is ``-s_h * |i - j|``, i.e. zero on the diagonal and
    linearly decreasing with distance in both directions.

    Args:
        max_positions: Side length of the square position grid.
        attention_heads: Number of heads (must be >= 1). Need not be a power
            of two.

    Returns:
        A floating-point tensor of shape
        ``(attention_heads, max_positions, max_positions)``.

    Raises:
        ValueError: If ``attention_heads`` is not positive.
    """
    if attention_heads < 1:
        # math.log2 would otherwise fail with an opaque "math domain error".
        raise ValueError(
            f"attention_heads must be a positive integer, got {attention_heads}"
        )

    def get_slopes_power_of_2(n):
        # Geometric sequence start, start**2, ..., start**n with
        # start = 2 ** (-8 / n).
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    def get_slopes(n):
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        # Not a power of two: take the slopes for the largest power of two
        # below n, then fill the remaining heads with every other slope from
        # the sequence for the next power of two.
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )

    slopes = torch.tensor(get_slopes(attention_heads))

    # -|i - j| for every pair of positions: symmetric, zero on the diagonal.
    positions = torch.arange(max_positions)
    pos_bias = -torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))

    # (H, 1, 1) * (1, P, P) broadcasts to (H, P, P).
    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0)
    return alibi_bias
|
|
|
|
|
def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
    """Restrict a square attention bias to the positions selected by a mask.

    The bias is viewed as ``(orig_B, H, orig_T, orig_T)``; rows and then
    columns whose position is ``True`` in ``mask_indices`` are kept.

    Args:
        alibi_bias: Tensor viewable as ``(orig_B, H, orig_T, orig_T)``.
        mask_indices: Boolean mask that broadcasts as ``(orig_B, 1, orig_T)``
            after ``unsqueeze(1)`` -- presumably shape ``(orig_B, orig_T)``;
            every batch row is expected to select the same number ``M`` of
            positions (TODO confirm against callers).
        orig_B: Batch size used for the view.
        orig_T: Sequence length used for the view.

    Returns:
        Tensor of shape ``(orig_B * H, M, M)``. An all-False mask yields an
        empty ``(orig_B * H, 0, 0)`` tensor.
    """
    alibi_bias = alibi_bias.view(orig_B, -1, orig_T, orig_T)
    H = alibi_bias.size(1)
    alibi_mask = mask_indices.unsqueeze(1)

    # Keep the selected rows. masked_select flattens, so restore the layout.
    kept_rows = alibi_bias.masked_select(alibi_mask.unsqueeze(-1))
    # Derive M explicitly: view(..., -1, ...) cannot infer a size when the
    # selection is empty (0 elements), which made an all-False mask crash.
    per_position = orig_B * H * orig_T
    M = kept_rows.numel() // per_position if per_position else 0
    kept_rows = kept_rows.view(orig_B, H, M, orig_T)

    # Keep the selected columns of the remaining rows.
    kept = kept_rows.masked_select(alibi_mask.unsqueeze(-2))
    return kept.view(orig_B * H, M, M)
|
|
|
|
|
|
|
|
|