File size: 1,214 Bytes
e8aab00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional

class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder outputs.

    Each encoder position is scored as ``V(tanh(W1(enc) + W2(dec)))``. The
    scores are normalized with a softmax over the sequence axis, and the
    context vector is the resulting weighted sum of the encoder outputs.

    Args:
        hidden_size: Feature size of the decoder hidden state and of the
            internal attention projection.
        encoder_size: Feature size of each encoder output. Defaults to
            ``hidden_size * 2`` (the previously hard-coded value; presumably
            a bidirectional encoder — confirm against the caller).
    """

    def __init__(self, hidden_size: int, encoder_size: Optional[int] = None):
        super().__init__()
        if encoder_size is None:
            encoder_size = hidden_size * 2
        self.W1 = nn.Linear(encoder_size, hidden_size)  # projects encoder outputs
        self.W2 = nn.Linear(hidden_size, hidden_size)  # projects decoder state
        self.V = nn.Linear(hidden_size, 1)  # reduces to one score per position

    @staticmethod
    def _apply_mask(
        attention_logits: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Return logits with positions where ``mask == 0`` set to the dtype's minimum.

        A fixed ``-1e9`` overflows float16 (max magnitude ~65504), which makes
        ``masked_fill`` raise under half precision / autocast. The dtype's most
        negative finite value still drives the softmax weight to exactly 0.
        """
        fill_value = torch.finfo(attention_logits.dtype).min
        # mask is (batch, src_len); logits are (batch, src_len, 1).
        return attention_logits.masked_fill(mask.unsqueeze(-1) == 0, fill_value)

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the attention context for one decoder step.

        Args:
            decoder_hidden: Decoder state of shape (batch, hidden_size).
            encoder_outputs: Encoder states of shape
                (batch, src_len, encoder_size).
            mask: Optional (batch, src_len) tensor; positions equal to 0
                (or False) receive zero attention weight.

        Returns:
            A tuple ``(context, attention_weights)``. ``context`` has shape
            (batch, encoder_size) and ``attention_weights`` has shape
            (batch, src_len), where each row sums to 1.
        """
        # Broadcast the decoder state across every source position.
        hidden_expanded = decoder_hidden.unsqueeze(1)
        score = torch.tanh(
            self.W1(encoder_outputs) + self.W2(hidden_expanded)
        )
        attention_logits = self.V(score)  # (batch, src_len, 1)

        if mask is not None:
            attention_logits = self._apply_mask(attention_logits, mask)

        # Softmax over the source-length axis, then drop the trailing 1.
        attention_weights = F.softmax(attention_logits, dim=1).squeeze(2)
        context = torch.bmm(
            attention_weights.unsqueeze(1),
            encoder_outputs
        ).squeeze(1)

        return context, attention_weights