import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention over encoder outputs whose feature
    dimension is hidden_size * 2 (e.g. from a bidirectional encoder)."""

    def __init__(self, hidden_size: int):
        super(BahdanauAttention, self).__init__()
        # W1 projects encoder outputs (hidden_size * 2 -> hidden_size),
        # W2 projects the decoder hidden state, and V reduces the combined
        # score to a single logit per source position.
        self.W1 = nn.Linear(hidden_size * 2, hidden_size)
        self.W2 = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # decoder_hidden:  (batch, hidden_size)
        # encoder_outputs: (batch, seq_len, hidden_size * 2)
        # mask:            (batch, seq_len), with 0 marking padded positions

        # (batch, 1, hidden_size) so the decoder state broadcasts across
        # all seq_len source positions in the sum below.
        hidden_expanded = decoder_hidden.unsqueeze(1)

        # Additive scoring: tanh(W1 @ h_enc + W2 @ h_dec)
        # -> (batch, seq_len, hidden_size)
        score = torch.tanh(
            self.W1(encoder_outputs) + self.W2(hidden_expanded)
        )

        # (batch, seq_len, 1)
        attention_logits = self.V(score)

        # Push padded positions toward -inf so their post-softmax
        # weight is effectively zero.
        if mask is not None:
            attention_logits = attention_logits.masked_fill(
                mask.unsqueeze(-1) == 0, -1e9
            )

        # Normalize over source positions -> (batch, seq_len)
        attention_weights = F.softmax(attention_logits, dim=1).squeeze(2)

        # Weighted sum of encoder outputs -> (batch, hidden_size * 2)
        context = torch.bmm(
            attention_weights.unsqueeze(1), encoder_outputs
        ).squeeze(1)

        return context, attention_weights
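

# Usage sketch: a minimal smoke test of the module above. The batch size,
# sequence length, and hidden size are illustrative assumptions, as is the
# choice to mask the last two positions as padding; encoder_outputs are
# random stand-ins for a real (bidirectional) encoder's outputs.
if __name__ == "__main__":
    batch_size, seq_len, hidden_size = 4, 10, 128
    attention = BahdanauAttention(hidden_size)

    decoder_hidden = torch.randn(batch_size, hidden_size)
    encoder_outputs = torch.randn(batch_size, seq_len, hidden_size * 2)

    # Pretend the last two positions of every sequence are padding.
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    mask[:, -2:] = 0

    context, weights = attention(decoder_hidden, encoder_outputs, mask)
    print(context.shape)  # torch.Size([4, 256])
    print(weights.shape)  # torch.Size([4, 10])

    # Each row of weights sums to 1, and masked positions get ~0 weight.
    assert torch.allclose(weights.sum(dim=1), torch.ones(batch_size))
    assert weights[:, -2:].max().item() < 1e-6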