| import torch |
| import numpy as np |
| import math |
|
|
|
|
class TriangularCausalMask:
    """Boolean mask that is True strictly above the main diagonal.

    The stored tensor has shape ``[B, 1, L, L]`` and dtype ``torch.bool``.
    Entry ``[b, 0, i, j]`` is True exactly when ``j > i``.

    NOTE(review): presumably True marks key positions a query must not
    attend to (future positions) -- confirm against the attention caller.
    """

    def __init__(self, B: int, L: int, device="cpu"):
        """Build the mask.

        Args:
            B: Size of the leading (batch) dimension.
            L: Size of the two trailing square dimensions.
            device: Device on which the mask tensor is allocated.
        """
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            # Allocate on the target device directly instead of building the
            # tensor on the host and copying it over afterwards.
            self._mask = torch.triu(
                torch.ones(mask_shape, dtype=torch.bool, device=device),
                diagonal=1,
            )

    @property
    def mask(self) -> torch.Tensor:
        """Return the ``[B, 1, L, L]`` boolean mask tensor."""
        return self._mask
|
|
class QuestionMask:
    """Boolean mask that is True in the last column for all rows but the last.

    The stored tensor has shape ``[B, 1, L, L]`` and dtype ``torch.bool``.
    Entry ``[b, 0, i, L - 1]`` is True for every ``i < L - 1``; every other
    entry is False.

    NOTE(review): presumably this hides the final position from all earlier
    positions -- confirm the intended semantics against the caller.
    """

    def __init__(self, B: int, L: int, device="cpu"):
        """Build the mask.

        Args:
            B: Size of the leading (batch) dimension.
            L: Size of the two trailing square dimensions.
            device: Device on which the mask tensor is allocated.
        """
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            # Allocate on the target device directly instead of building the
            # tensor on the host and copying it over afterwards.
            self._mask = torch.zeros(mask_shape, dtype=torch.bool, device=device)
            # With L == 0 there is no last column; indexing it with -1 would
            # raise IndexError, so the (empty) mask is left untouched.
            if L > 0:
                self._mask[:, :, :-1, -1] = True

    @property
    def mask(self) -> torch.Tensor:
        """Return the ``[B, 1, L, L]`` boolean mask tensor."""
        return self._mask
|
|
|
|
class ProbMask:
    """Rows of an upper-triangular mask gathered at the given indices.

    A ``[L, scores.shape[-1]]`` boolean matrix that is True strictly above
    the main diagonal is built, and for every (batch, head) pair the rows
    named by ``index`` are selected. The result is viewed with the shape of
    ``scores``.

    NOTE(review): presumably ``index`` holds the query positions chosen by a
    sparse-attention step and True marks keys to be excluded -- confirm
    against the caller.
    """

    def __init__(self, B: int, H: int, L: int, index, scores, device="cpu"):
        """Build the mask.

        Args:
            B: Size of the batch dimension.
            H: Size of the head dimension.
            L: Number of rows of the underlying triangular matrix.
            index: Integer tensor of row indices; it is used as the third
                advanced index next to ``[B, 1, 1]`` and ``[1, H, 1]`` index
                tensors, so it must broadcast with ``[B, H, 1]``.
            scores: Tensor whose last dimension gives the number of columns
                and whose full shape the resulting mask is viewed as.
            device: Device on which the mask is built.
        """
        n_cols = scores.shape[-1]
        upper = torch.ones(L, n_cols, dtype=torch.bool, device=device).triu(1)
        # expand() creates a broadcast view; no per-batch/head copy is made.
        expanded = upper[None, None, :].expand(B, H, L, n_cols)
        # Index tensors live on the same device as the mask so that no
        # implicit host-to-device transfer happens while gathering.
        batch_idx = torch.arange(B, device=device)[:, None, None]
        head_idx = torch.arange(H, device=device)[None, :, None]
        # Advanced indexing already yields a fresh tensor on ``device``.
        indicator = expanded[batch_idx, head_idx, index, :]
        self._mask = indicator.view(scores.shape)

    @property
    def mask(self) -> torch.Tensor:
        """Return the boolean mask tensor shaped like ``scores``."""
        return self._mask
|
|
|
|
class LocalMask:
    """Boolean mask that is True outside a lower band around the diagonal.

    With ``len = ceil(log2(L))``, entry ``[b, 0, i, j]`` of the
    ``[B, 1, L, S]`` mask is True when ``j > i`` or when ``j < i - len``.
    The entries left False are those with ``i - len <= j <= i``.

    NOTE(review): presumably True marks key positions a query must not
    attend to -- confirm against the attention caller.

    Attributes set by ``__init__``:
        len: Band width below the diagonal, ``ceil(log2(L))``.
        _mask1: True where ``j > i``.
        _mask2: True where ``j < i - len``.
        _mask: Element-wise OR of ``_mask1`` and ``_mask2``.
    """

    def __init__(self, B: int, L: int, S: int, device="cpu"):
        """Build the mask.

        Args:
            B: Size of the leading (batch) dimension.
            L: Number of rows; also determines the band width.
            S: Number of columns.
            device: Device on which the mask tensors are allocated.
        """
        mask_shape = [B, 1, L, S]
        with torch.no_grad():
            self.len = math.ceil(np.log2(L))
            # One all-True tensor, allocated on the target device, serves as
            # the input of both triu calls (triu returns a new tensor).
            all_true = torch.ones(mask_shape, dtype=torch.bool, device=device)
            self._mask1 = torch.triu(all_true, diagonal=1)
            self._mask2 = ~torch.triu(all_true, diagonal=-self.len)
            # Logical OR spelled explicitly rather than via bool addition.
            self._mask = self._mask1 | self._mask2

    @property
    def mask(self) -> torch.Tensor:
        """Return the combined ``[B, 1, L, S]`` boolean mask tensor."""
        return self._mask
|
|