| |
| |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class RelativePositionBias(nn.Module):
    """Learned relative position bias in the bucketed T5 style.

    Each signed query/key offset is mapped to one of ``num_buckets``
    buckets — exact buckets for small offsets, logarithmically spaced
    buckets for larger ones up to ``max_distance`` — and an embedding
    table provides one learned scalar per (bucket, head).
    """

    def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        # One learned bias value per (bucket, head) pair.
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Map signed relative positions to bucket indices.

        Small distances get one exact bucket each; distances beyond
        ``num_buckets // 2`` share log-spaced buckets, saturating at the
        last bucket for anything >= ``max_distance``.

        NOTE: the log expression below is also evaluated for distances in
        the exact range (including 0, i.e. log(0)); those values are
        discarded by ``torch.where``, matching the reference T5 code.
        """
        distance = -relative_position
        bucket = torch.zeros_like(distance)
        if bidirectional:
            # Use the upper half of the buckets for "future" offsets.
            num_buckets //= 2
            bucket = bucket + (distance < 0).to(torch.long) * num_buckets
            distance = distance.abs()
        else:
            # Causal: future offsets collapse to distance zero.
            distance = distance.clamp(min=0)

        max_exact = num_buckets // 2
        # Log-spaced bucket index for distances past the exact range,
        # capped at the final bucket.
        log_bucket = max_exact + (
            torch.log(distance.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        log_bucket = log_bucket.clamp(max=num_buckets - 1)

        return bucket + torch.where(distance < max_exact, distance, log_bucket)

    def compute_bias(self, qlen, klen, step=None):
        """Return the bias tensor, shape ``(1, n_heads, qlen, klen)``.

        ``step`` offsets the query positions (incremental decoding);
        ``None`` is treated as 0.
        """
        offset = 0 if step is None else step
        device = self.relative_attention_bias.weight.device
        query_pos = torch.arange(offset, offset + qlen, dtype=torch.long, device=device).unsqueeze(1)
        key_pos = torch.arange(klen, dtype=torch.long, device=device).unsqueeze(0)
        buckets = self._relative_position_bucket(
            key_pos - query_pos,
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        ).to(device)
        bias = self.relative_attention_bias(buckets)  # (qlen, klen, n_heads)
        return bias.permute(2, 0, 1).unsqueeze(0)  # (1, n_heads, qlen, klen)

    def forward(self, batch_size, qlen, klen, step=None):
        """Return the bias tiled per batch, shape ``(batch_size * n_heads, qlen, klen)``."""
        bias = self.compute_bias(qlen, klen, step)
        return bias.repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
|
|