import math

import torch

__all__ = ['ALiBiRelativePositionEmbedding']


def get_slopes(n):
    """
    Returns the n ALiBi head slopes. For a power-of-two n the slopes form a geometric
    sequence; otherwise the sequence for the closest smaller power of two is extended
    with every other slope from the sequence of twice that size.
    """

    def get_slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if math.log2(n).is_integer():
        slopes = get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        slopes = (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )

    return slopes
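# Illustrative values (a sketch, not part of the original module): for a power-of-two
# head count the slopes form a geometric sequence starting at 2 ** (-8 / n), e.g.
#
#   get_slopes(4)  ->  [0.25, 0.0625, 0.015625, 0.00390625]
#   get_slopes(8)  ->  [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
#
# For other head counts, e.g. n=6, the four slopes for n=4 are followed by every other
# slope of the n=8 sequence: [0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125].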
def build_slopes(num_attention_heads, num_attention_heads_alibi):
    """
    Builds a (num_attention_heads, 1, 1) slopes tensor. The first
    num_attention_heads_alibi heads get ALiBi slopes; any remaining heads get a slope
    of 0, i.e. no positional bias.
    """
    slopes = torch.Tensor(
        get_slopes(num_attention_heads_alibi) + [0] * (num_attention_heads - num_attention_heads_alibi)
    ).cuda()
    return slopes.unsqueeze(-1).unsqueeze(-1)
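# Shape sketch (a hypothetical example; assumes a CUDA device, since the tensor is
# created with .cuda()): build_slopes(num_attention_heads=8, num_attention_heads_alibi=4)
# returns an (8, 1, 1) tensor holding [0.25, 0.0625, 0.015625, 0.00390625, 0, 0, 0, 0],
# ready to broadcast against a (num_heads, query_length, key_length) position tensor.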
def build_relative_position(query_length, key_length, num_attention_heads):
    """
    Builds a (num_attention_heads, query_length, key_length) tensor of absolute
    relative distances |key_position - query_position|, identical for every head.
    """
    context_position = torch.arange(query_length)[:, None].cuda()
    memory_position = torch.arange(key_length)[None, :].cuda()
    relative_position = memory_position - context_position
    # Use the absolute distance (symmetric in both directions) and share it across heads.
    relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_attention_heads, -1, -1)

    return relative_position
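# Illustrative sketch: build_relative_position(3, 3, num_attention_heads=2) produces,
# for each of the two heads, the symmetric distance matrix
#
#   [[0, 1, 2],
#    [1, 0, 1],
#    [2, 1, 0]]
#
# for an overall shape of (2, 3, 3).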
class ALiBiRelativePositionEmbedding(torch.nn.Module):
    """
    ALiBi (Attention with Linear Biases) relative position embedding for auto-regressive decoder
    and joint encoder (symmetric for forward and backward distance).
    Based on https://arxiv.org/abs/2108.12409
    """

    def __init__(
        self, bidirectional, num_attention_heads, layer_type, num_attention_heads_alibi=None, max_seq_len=512
    ):
        """
        Args:
            bidirectional: Whether to use bidirectional relative position embedding
            num_attention_heads: Number of attention heads
            layer_type: Layer type. Can be one of [LayerType.encoder or LayerType.decoder]. Will determine the bias construction.
            num_attention_heads_alibi: Number of attention heads for which ALiBi bias will be used
            max_seq_len: Maximum sequence length for the precomputed relative positions (larger values use more memory). Longer sequences are handled by computing the ALiBi mask on the fly.
        """
        super().__init__()

        # Default to applying ALiBi to every head.
        if (num_attention_heads_alibi is None) or (num_attention_heads_alibi <= 0):
            num_attention_heads_alibi = num_attention_heads

        if num_attention_heads_alibi > num_attention_heads:
            raise ValueError(
                f"num_attention_heads_alibi ({num_attention_heads_alibi}) cannot be larger than num_attention_heads ({num_attention_heads})"
            )

        self.bidirectional = bidirectional
        self.num_attention_heads = num_attention_heads
        self.layer_type = layer_type
        self.num_attention_heads_alibi = num_attention_heads_alibi
        self.max_seq_len = max_seq_len

        # Precompute the per-head slopes and the relative positions up to max_seq_len.
        self.slopes = build_slopes(num_attention_heads, num_attention_heads_alibi)
        self.relative_position = build_relative_position(max_seq_len, max_seq_len, num_attention_heads)

    def forward(self, query_seq_length, key_seq_length):
        # Reuse the cached relative positions when possible; otherwise rebuild them on the fly.
        max_seq_len = max(query_seq_length, key_seq_length)
        if max_seq_len > self.max_seq_len:
            relative_position = build_relative_position(max_seq_len, max_seq_len, self.num_attention_heads)
        else:
            relative_position = self.relative_position

        # Crop to (num_attention_heads, query_seq_length, key_seq_length).
        relative_position = relative_position[:, :query_seq_length, :key_seq_length]

        # For the causal (unidirectional) case, zero the bias for future positions.
        if not self.bidirectional:
            relative_position = torch.tril(relative_position)

        # Bias of shape (1, num_attention_heads, query_seq_length, key_seq_length),
        # added to the attention scores: more distant positions get a larger penalty.
        return -relative_position.unsqueeze(0) * self.slopes
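

if __name__ == "__main__":
    # Hedged usage sketch (not part of the original module): build a causal
    # (decoder-style) ALiBi bias. layer_type is only stored by __init__ and not used
    # in forward(), so None is passed here purely as a placeholder.
    if torch.cuda.is_available():
        alibi = ALiBiRelativePositionEmbedding(
            bidirectional=False,
            num_attention_heads=8,
            layer_type=None,
            num_attention_heads_alibi=None,
            max_seq_len=512,
        )
        bias = alibi(query_seq_length=128, key_seq_length=128)
        # The bias broadcasts as (1, num_heads, query_length, key_length) and is
        # added to the raw attention scores before the softmax.
        print(bias.shape)  # torch.Size([1, 8, 128, 128])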