| """Global attention modules (Luong / Bahdanau)""" |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from onmt.modules.sparse_activations import sparsemax |
| from onmt.utils.misc import sequence_mask |
|
|
| |
| |
| |
| |
|
|
|
|
class GlobalAttention(nn.Module):
    r"""
    Global attention takes a matrix and a query vector. It
    then computes a parameterized convex combination of the matrix
    based on the input query.

    Constructs a unit mapping a query `q` of size `dim`
    and a source matrix `H` of size `n x dim`, to an output
    of size `dim`.

    .. mermaid::

       graph BT
          A[Query]
          subgraph RNN
            C[H 1]
            D[H 2]
            E[H N]
          end
          F[Attn]
          G[Output]
          A --> F
          C --> F
          D --> F
          E --> F
          C -.-> G
          D -.-> G
          E -.-> G
          F --> G

    All models compute the output as
    :math:`c = \sum_{j=1}^{\text{SeqLength}} a_j H_j` where
    :math:`a_j` is the softmax of a score function.
    They then apply a projection layer to ``[q, c]``.

    However, they differ in how they compute the attention score.

    * Luong Attention (dot, general):
       * dot: :math:`\text{score}(H_j, q) = H_j^T q`
       * general: :math:`\text{score}(H_j, q) = H_j^T W_a q`

    * Bahdanau Attention (mlp):
       * :math:`\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a H_j)`

    Args:
        dim (int): dimensionality of query and key
        coverage (bool): use coverage term
        attn_type (str): type of attention to use, options [dot,general,mlp]
        attn_func (str): attention function to use, options [softmax,sparsemax]

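    Example:
        A minimal usage sketch; the sizes and tensors below are illustrative
        only (any consistent ``dim`` would do)::

            attn = GlobalAttention(dim=8, attn_type="general")
            query = torch.randn(2, 3, 8)      # (batch, tgt_len, dim)
            enc_out = torch.randn(2, 5, 8)    # (batch, src_len, dim)
            src_len = torch.tensor([5, 4])
            context, align = attn(query, enc_out, src_len=src_len)
            # context: (2, 3, 8); align: (2, 3, 5), rows sum to 1 under softmax
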
| """ |

    def __init__(self, dim, coverage=False, attn_type="dot", attn_func="softmax"):
        super(GlobalAttention, self).__init__()

        self.dim = dim
        assert attn_type in [
            "dot",
            "general",
            "mlp",
        ], "Please select a valid attention type (got {:s}).".format(attn_type)
        self.attn_type = attn_type
        assert attn_func in [
            "softmax",
            "sparsemax",
        ], "Please select a valid attention function."
        self.attn_func = attn_func

        if self.attn_type == "general":
            self.linear_in = nn.Linear(dim, dim, bias=False)
        elif self.attn_type == "mlp":
            self.linear_context = nn.Linear(dim, dim, bias=False)
            self.linear_query = nn.Linear(dim, dim, bias=True)
            self.v = nn.Linear(dim, 1, bias=False)

        # only the mlp variant uses a bias in the output projection
        out_bias = self.attn_type == "mlp"
        self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias)

        if coverage:
            self.linear_cover = nn.Linear(1, dim, bias=False)

    def score(self, h_t, h_s):
        """
        Args:
            h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
            h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``

        Returns:
            FloatTensor: raw attention scores (unnormalized) for each src index
            ``(batch, tgt_len, src_len)``
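
        Example:
            A shape-only sketch; ``attn`` is assumed to be an already-built
            ``GlobalAttention(dim=8)`` and the tensors are illustrative::

                scores = attn.score(torch.randn(2, 3, 8), torch.randn(2, 5, 8))
                # scores: (2, 3, 5), one raw score per (query, source) pair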
| """ |
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t = self.linear_in(h_t)
            h_s_ = h_s.transpose(1, 2)
            # (batch, tgt_len, dim) x (batch, dim, src_len)
            # -> (batch, tgt_len, src_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t)
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous())
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, tgt_len, src_len, dim)
            wquh = torch.tanh(wq + uh)

            return self.v(wquh).view(tgt_batch, tgt_len, src_len)

    def forward(self, src, enc_out, src_len=None, coverage=None):
        """
        Args:
            src (FloatTensor): query vectors ``(batch, tgt_len, dim)``
            enc_out (FloatTensor): encoder output vectors ``(batch, src_len, dim)``
            src_len (LongTensor): source context lengths ``(batch,)``
            coverage (FloatTensor): None (not supported yet)

        Returns:
            (FloatTensor, FloatTensor):

            * Computed output vector ``(batch, tgt_len, dim)``
            * Attention distributions for each query
              ``(batch, tgt_len, src_len)``
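
        Example:
            A single-step sketch; ``attn`` is assumed to be an already-built
            ``GlobalAttention(dim=8)`` and the tensors are illustrative::

                context, align = attn(torch.randn(2, 8), torch.randn(2, 5, 8))
                # a 2D query triggers one-step mode: context (2, 8), align (2, 5)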
| """ |

        # one-step input: a 2D query means a single decoding step
        if src.dim() == 2:
            one_step = True
            src = src.unsqueeze(1)
        else:
            one_step = False

        batch, src_l, dim = enc_out.size()
        batch_, target_l, dim_ = src.size()

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            enc_out += self.linear_cover(cover).view_as(enc_out)
            enc_out = torch.tanh(enc_out)

        # compute raw attention scores
        align = self.score(src, enc_out)

        if src_len is not None:
            mask = sequence_mask(src_len, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # broadcast the mask over tgt_len
            align.masked_fill_(~mask, -float("inf"))

        # softmax / sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, src_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, src_l), -1)
        align_vectors = align_vectors.view(batch, target_l, src_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, enc_out)

        # concatenate [c, q] and project back to dim
        concat_c = torch.cat([c, src], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

        return attn_h, align_vectors