# Bahdanau (additive) attention module.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class BahdanauAttention(nn.Module):
    """Bahdanau (additive) attention over bidirectional encoder outputs.

    Scores each encoder timestep against the current decoder state with a
    small feed-forward network: ``v^T · tanh(W1·enc_t + W2·dec)``, then
    returns the attention-weighted sum of the encoder outputs.

    Args:
        hidden_size: Decoder hidden dimension. Encoder outputs are expected
            to have feature dimension ``2 * hidden_size`` (W1's in-features
            imply a bidirectional encoder — confirm against the caller).
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # W1 projects encoder outputs (2*hidden -> hidden); W2 projects the
        # decoder state (hidden -> hidden); V collapses to a scalar score.
        self.W1 = nn.Linear(hidden_size * 2, hidden_size)
        self.W2 = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the context vector and attention distribution.

        Args:
            decoder_hidden: ``(batch, hidden_size)`` decoder state.
            encoder_outputs: ``(batch, seq_len, 2*hidden_size)`` encoder states.
            mask: Optional ``(batch, seq_len)`` tensor; positions equal to 0
                receive (near-)zero attention weight.

        Returns:
            Tuple of:
              * context — ``(batch, 2*hidden_size)`` weighted encoder sum.
              * attention_weights — ``(batch, seq_len)`` softmax distribution.
        """
        # (batch, 1, hidden) so the addition broadcasts over seq_len.
        hidden_expanded = decoder_hidden.unsqueeze(1)
        # Additive score: (batch, seq_len, hidden).
        score = torch.tanh(
            self.W1(encoder_outputs) + self.W2(hidden_expanded)
        )
        # (batch, seq_len, 1) unnormalized logits.
        attention_logits = self.V(score)
        if mask is not None:
            # FIX: the original hard-coded -1e9, which overflows to -inf under
            # float16/bfloat16; the dtype's most-negative finite value is safe
            # for any floating dtype and behaves identically under float32.
            attention_logits = attention_logits.masked_fill(
                mask.unsqueeze(-1) == 0,
                torch.finfo(attention_logits.dtype).min,
            )
        # Normalize over the sequence axis, then drop the trailing singleton.
        attention_weights = F.softmax(attention_logits, dim=1).squeeze(2)
        # (batch, 1, seq_len) @ (batch, seq_len, 2*hidden) -> (batch, 2*hidden).
        context = torch.bmm(
            attention_weights.unsqueeze(1),
            encoder_outputs
        ).squeeze(1)
        return context, attention_weights