import torch
import torch.nn as nn
from typing import Tuple, Optional


class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        hidden_size: int,
        attention: nn.Module,
        num_layers: int = 2,
        dropout: float = 0.3
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.attention = attention

        # Token embeddings; index 0 is reserved for padding.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=0
        )

        # The LSTM consumes the embedded token concatenated with the attention
        # context (2 * hidden_size, matching a bidirectional encoder).
        self.lstm = nn.LSTM(
            input_size=embed_size + hidden_size * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # The output projection sees the LSTM output, the context vector,
        # and the embedded input token.
        self.fc_out = nn.Linear(
            hidden_size + hidden_size * 2 + embed_size,
            vocab_size
        )
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_size + hidden_size * 2 + embed_size)

    def forward(
        self,
        input_token: torch.Tensor,
        decoder_hidden: torch.Tensor,
        decoder_cell: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # input_token: (batch,) -> embedded: (batch, 1, embed_size)
        embedded = self.dropout(self.embedding(input_token.unsqueeze(1)))

        # Attend over encoder outputs using the top decoder layer's hidden state.
        top_hidden = decoder_hidden[-1]
        context, attention_weights = self.attention(
            top_hidden, encoder_outputs, mask
        )
        context = self.dropout(context)

        # Concatenate embedding and context along the feature dimension:
        # (batch, 1, embed_size + hidden_size * 2)
        lstm_input = torch.cat((embedded, context.unsqueeze(1)), dim=2)
        output, (decoder_hidden, decoder_cell) = self.lstm(
            lstm_input, (decoder_hidden, decoder_cell)
        )

        # Drop the time dimension and project to the vocabulary.
        output = output.squeeze(1)        # (batch, hidden_size)
        embedded = embedded.squeeze(1)    # (batch, embed_size)
        output_context = torch.cat((output, context, embedded), dim=1)
        output_context = self.layer_norm(output_context)
        prediction = self.fc_out(output_context)

        return prediction, decoder_hidden, decoder_cell, attention_weights
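
# --- Usage sketch (illustrative, not part of the original module) -------------
# The Decoder only assumes that `attention` maps
# (top_hidden, encoder_outputs, mask) -> (context, attention_weights).
# AdditiveAttention below is one hypothetical module satisfying that contract,
# with encoder outputs of size 2 * hidden_size (a bidirectional encoder).
# Treat everything here as a minimal smoke test under those assumptions,
# not as the original project's attention implementation.

class AdditiveAttention(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        # Score each source position from [decoder hidden; encoder output].
        self.energy = nn.Linear(hidden_size + hidden_size * 2, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(
        self,
        top_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # top_hidden: (batch, hidden), encoder_outputs: (batch, src_len, 2 * hidden)
        src_len = encoder_outputs.size(1)
        hidden_expanded = top_hidden.unsqueeze(1).expand(-1, src_len, -1)
        energy = torch.tanh(
            self.energy(torch.cat((hidden_expanded, encoder_outputs), dim=2))
        )
        scores = self.score(energy).squeeze(2)               # (batch, src_len)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1)                # (batch, src_len)
        # Weighted sum of encoder outputs -> context: (batch, 2 * hidden)
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, weights


if __name__ == "__main__":
    # Hypothetical sizes, chosen only for this smoke test.
    vocab_size, embed_size, hidden_size = 100, 32, 64
    batch, src_len, num_layers = 4, 10, 2

    decoder = Decoder(
        vocab_size, embed_size, hidden_size,
        AdditiveAttention(hidden_size), num_layers=num_layers
    )

    input_token = torch.randint(1, vocab_size, (batch,))
    decoder_hidden = torch.zeros(num_layers, batch, hidden_size)
    decoder_cell = torch.zeros(num_layers, batch, hidden_size)
    encoder_outputs = torch.randn(batch, src_len, hidden_size * 2)
    mask = torch.ones(batch, src_len, dtype=torch.bool)

    prediction, h, c, attn = decoder(
        input_token, decoder_hidden, decoder_cell, encoder_outputs, mask
    )
    print(prediction.shape)  # torch.Size([4, 100])
    print(attn.shape)        # torch.Size([4, 10])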