# Provenance: imported from a Hugging Face repository commit e8aab00 ("add generator")
# by user mohamedahraf273. (Original scrape residue converted to a comment so the
# file is valid Python.)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over bidirectional encoder states.

    Scores each encoder timestep against the current decoder hidden state via
    ``v . tanh(W1 h_enc + W2 h_dec)`` and returns the attention-weighted
    context vector together with the normalized weights.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Encoder states are assumed bidirectional, hence the 2x input width.
        self.W1 = nn.Linear(hidden_size * 2, hidden_size)
        self.W2 = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute context vector and attention weights.

        Args:
            decoder_hidden: (batch, hidden) current decoder state.
            encoder_outputs: (batch, src_len, hidden * 2) encoder states.
            mask: optional (batch, src_len) tensor; positions equal to 0
                are excluded from attention.

        Returns:
            context: (batch, hidden * 2) weighted sum of encoder states.
            weights: (batch, src_len) attention distribution over timesteps.
        """
        # Broadcast the decoder state across all source timesteps.
        query = self.W2(decoder_hidden.unsqueeze(1))
        keys = self.W1(encoder_outputs)
        energy = self.V(torch.tanh(keys + query)).squeeze(-1)  # (batch, src_len)

        if mask is not None:
            # Large negative logit ~ zero probability after softmax.
            energy = energy.masked_fill(mask == 0, -1e9)

        weights = F.softmax(energy, dim=1)
        # (batch, 1, src_len) @ (batch, src_len, 2H) -> (batch, 1, 2H)
        context = (weights.unsqueeze(1) @ encoder_outputs).squeeze(1)
        return context, weights