from typing import Optional

import math

import torch as T
from torch import nn, Tensor
from torch.nn import functional as F

import opt_einsum as oe

einsum = oe.contract


def masked_softmax(xs: Tensor, mask: Tensor, dim: int = -1) -> Tensor:
    """Softmax over `dim`; positions where `mask` is False receive (effectively) zero weight."""
    xs = xs.masked_fill(~mask, -1e9)
    return F.softmax(xs, dim=dim)
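# A minimal usage sketch for masked_softmax (the shapes and values below are
# illustrative assumptions, not part of the module):
#
#   scores = T.randn(2, 3)
#   mask = T.tensor([[True, True, False],
#                    [True, True, False]])
#   probs = masked_softmax(scores, mask)   # probs[..., 2] is ~0; each row sums to 1
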
class Attention(nn.Module):
    def __init__(
        self,
        kind: str,
        query_dim: int,
        input_dim: int,
        output_dim: Optional[int] = None,
        activation: str = 'auto',
        scaled=True,
    ):
        super().__init__()
        assert kind in ('dot', 'linear')

        self.kind = kind
        self.Dq = query_dim
        self.Din = input_dim
        self.Dout = output_dim or self.Din
        self.activation = activation
        # `scaled` is either a bool or the string 'seqlen' (dot mode only).
        self.scaled = scaled

        self.Wq_ = nn.Linear(self.Dq, self.Din)
        self.Wk_ = nn.Linear(self.Din, self.Din)
        self.Wv_ = nn.Linear(self.Din, self.Dout)
        # The output projection acts on the attended values, which are Dout-dimensional.
        self.Wz_ = nn.Linear(self.Dout, self.Dout)

    def forward(
        self,
        query: Tensor,
        data: Tensor,
        content_mask: Optional[Tensor] = None,
        prejudice_mask: Optional[Tensor] = None,
    ):
        """Attend over `data` with a per-word query.

        Shapes (as implied by the contractions below):
            query:          (B, S, W, Dq)   batch, sequence positions, words per position, query dim
            data:           (B, S, Din)
            content_mask:   (B, S, W) bool, marks real (non-padding) words
            prejudice_mask: (B, S, S) bool, restricts which query/key pairs may interact (dot mode only)
        """
        dimB, dimS, dimW, dimI = query.shape

        # Project queries, keys and values.
        qs = self.Wq_(query)   # (B, S, W, Din)
        ks = self.Wk_(data)    # (B, S, Din)
        vs = self.Wv_(data)    # (B, S, Dout)

        if content_mask is not None:
            # A sequence position counts as real if any of its words is unmasked.
            words_mask = content_mask.any(2)
        else:
            words_mask = qs.new_ones((dimB, dimS))

        if self.kind == 'linear':
            # Linear attention: aggregate keys against values once, so the full
            # (S x S) attention map is never materialised.
            assert prejudice_mask is None, "Linear mode does not support prejudice_mask."
            assert content_mask is not None, "Linear mode requires a content_mask."
            qs = T.relu(qs) * content_mask.unsqueeze(3)   # zero out padded words
            ks = T.relu(ks) * words_mask.unsqueeze(2)     # zero out padded positions

            vks = einsum("bsi, bsz -> bzi", ks, vs)       # (B, Dout, Din)
            zs = einsum("bswi, bzi -> bswz", qs, vks)     # (B, S, W, Dout)

            if self.scaled:
                ks = ks.sum(1)                            # (B, Din)
                denom = einsum("bswi, bi -> bsw", qs, ks) + 1e-9
                zs = zs / denom.unsqueeze(3)              # normalise each query word's output

            # No explicit attention map exists in linear mode.
            att_map = None

        elif self.kind == 'dot':
            # Score every query word against every key position:
            # (B, S, W, Din) x (B, S, Din) -> (B, Sq, Sk, W)
            att_map = einsum("bqwi, bki -> bqkw", qs, ks)

            if self.scaled == 'seqlen':
                # Scale by the square root of the number of unmasked positions.
                att_map_ndim = len(att_map.shape) - 1
                norm_coeff = words_mask.sum(1).view(-1, *([1] * att_map_ndim))
                att_map = att_map / T.sqrt(norm_coeff.float())
            elif self.scaled:
                # Standard scaled dot-product attention.
                att_map = att_map / math.sqrt(self.Din)

            if content_mask is None and prejudice_mask is None:
                att_map = F.softmax(att_map, dim=2)
            else:
                if content_mask is None:
                    assert prejudice_mask is not None
                    qk_mask = prejudice_mask.unsqueeze(3)   # (B, Sq, Sk, 1)
                elif prejudice_mask is None:
                    # Mask padded key positions and padded query words.
                    qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
                else:
                    # Both masks: padded key positions, further restricted by the prejudice mask.
                    qk_mask = words_mask.unsqueeze(1).unsqueeze(3)
                    qk_mask = qk_mask * prejudice_mask.unsqueeze(3)

                att_map = masked_softmax(att_map, qk_mask.bool(), dim=2)

            # Gather values per query word:
            # (B, Sq, Sk, W) x (B, Sk, Dout) -> (B, Sq, W, Dout)
            zs = einsum("bqkw, bkz -> bqwz", att_map, vs)

        zs = self.Wz_(zs)
        return zs, att_map
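

if __name__ == "__main__":
    # Minimal smoke-test sketch. The sizes below are illustrative assumptions,
    # not values fixed anywhere in this module.
    B, S, W, Dq, Din, Dout = 2, 4, 3, 8, 16, 16

    query = T.randn(B, S, W, Dq)
    data = T.randn(B, S, Din)
    content_mask = T.ones(B, S, W, dtype=T.bool)
    content_mask[:, -1, :] = False   # pretend the last position is padding

    dot_att = Attention('dot', query_dim=Dq, input_dim=Din, output_dim=Dout)
    zs, att_map = dot_att(query, data, content_mask=content_mask)
    print(zs.shape, att_map.shape)   # (B, S, W, Dout), (B, S, S, W)

    lin_att = Attention('linear', query_dim=Dq, input_dim=Din, output_dim=Dout)
    zs, att_map = lin_att(query, data, content_mask=content_mask)
    print(zs.shape, att_map)         # (B, S, W, Dout), None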