| """ Alibi position bias """ |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def pad_at_dim(t, pad, dim=-1, value=0.0):
    """Pad tensor `t` along `dim`; `pad` is a (before, after) tuple."""
    # F.pad takes pad amounts ordered from the last dimension inward, so
    # prepend a (0, 0) pair for every dimension to the right of `dim`.
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


class AlibiPositionalBias(nn.Module):
    """Adds the ALiBi bias -slope * |key_pos - query_pos| to attention logits."""

    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        # fixed (non-learned) per-head slopes, shaped (heads, 1, 1) so they
        # broadcast over the (i, j) distance matrix
        slopes = torch.tensor(self._get_slopes(heads))
        slopes = slopes.unsqueeze(1).unsqueeze(1)
        self.register_buffer("slopes", slopes, persistent=False)
        self.register_buffer("bias", None, persistent=False)

    def get_bias(self, i, j, device):
        # query positions are the *last* i of the j key positions (this
        # supports incremental decoding); bias[q, k] = -|k - q|
        i_arange = torch.arange(j - i, j, device=device)
        j_arange = torch.arange(j, device=device)
        bias = -torch.abs(
            j_arange.unsqueeze(0).unsqueeze(0) - i_arange.unsqueeze(1).unsqueeze(0)
        )  # shape (1, i, j)
        return bias

    @staticmethod
    def _get_slopes(heads):
        # geometric sequence 2^(-8/n), 2^(-16/n), ... as in the ALiBi paper
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        # for a head count that is not a power of two, use the slopes for the
        # closest smaller power of two, then fill the remainder with every
        # other slope from the next power of two
        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
                : heads - closest_power_of_2
            ]
        )

    def forward(self, qk_dots):
        h, i, j, device = *qk_dots.shape[-3:], qk_dots.device

        # reuse the cached bias when it is large enough in both dimensions;
        # slice from the bottom-right corner so the query rows stay aligned
        # with the last i positions, matching the anchoring in get_bias
        if (
            self.bias is not None
            and self.bias.shape[-2] >= i
            and self.bias.shape[-1] >= j
        ):
            return qk_dots + self.bias[..., -i:, -j:]

        bias = self.get_bias(i, j, device)
        bias = bias * self.slopes  # (heads, i, j)

        # zero-pad the head dimension for any heads beyond those with slopes
        num_heads_unalibied = h - bias.shape[0]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=0)
        self.register_buffer("bias", bias, persistent=False)

        return qk_dots + self.bias
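

# Minimal usage sketch (illustrative only; the batch, head, and sequence
# sizes below are arbitrary, not prescribed by the module).
if __name__ == "__main__":
    heads, seq_len = 8, 128
    alibi = AlibiPositionalBias(heads)

    # per-head slopes for 8 heads: 2^-1, 2^-2, ..., 2^-8
    print(alibi.slopes.flatten().tolist())

    # the bias is added directly to the pre-softmax attention logits
    qk_dots = torch.randn(1, heads, seq_len, seq_len)  # (batch, h, i, j)
    biased = alibi(qk_dots)
    assert biased.shape == qk_dots.shape

    # the added bias depends only on query/key distance:
    # head 0, query 127, key 125 -> -2 * slopes[0] = -1.0
    print((biased - qk_dots)[0, 0, -1, -3].item())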
|
|