Spaces:
Sleeping
Sleeping
File size: 1,214 Bytes
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention over a sequence of encoder states.

    Scores each encoder timestep against the current decoder hidden state
    via a small feed-forward network: ``v^T tanh(W1 h_enc + W2 h_dec)``.

    Layer shapes imply a bidirectional encoder: W1 consumes features of
    size ``hidden_size * 2`` while the decoder query has ``hidden_size``.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Attribute names kept as W1/W2/V so existing checkpoints load.
        self.W1 = nn.Linear(hidden_size * 2, hidden_size)  # projects encoder states
        self.W2 = nn.Linear(hidden_size, hidden_size)      # projects decoder query
        self.V = nn.Linear(hidden_size, 1)                 # collapses to a scalar score

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the context vector and attention distribution.

        Args:
            decoder_hidden: query, shape ``(batch, hidden_size)``.
            encoder_outputs: keys/values, shape ``(batch, seq_len, hidden_size * 2)``.
            mask: optional ``(batch, seq_len)`` tensor; positions equal to 0
                are excluded from the attention distribution.

        Returns:
            A ``(context, weights)`` pair with shapes
            ``(batch, hidden_size * 2)`` and ``(batch, seq_len)``.
        """
        # Broadcast the single query over every encoder timestep.
        query = decoder_hidden.unsqueeze(1)                      # (B, 1, H)
        energy = torch.tanh(self.W1(encoder_outputs) + self.W2(query))
        scores = self.V(energy).squeeze(-1)                      # (B, S)

        if mask is not None:
            # Large negative logit -> ~0 probability after softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = F.softmax(scores, dim=-1)                      # (B, S)

        # Weighted sum of encoder states: (B, 1, S) x (B, S, 2H) -> (B, 2H).
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, weights