import torch
import torch.nn as nn
from typing import Tuple, Optional


class Decoder(nn.Module):
    """Single-step attentional LSTM decoder for a seq2seq model."""

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        hidden_size: int,
        attention: nn.Module,
        num_layers: int = 2,
        dropout: float = 0.3
    ):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.attention = attention
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=0
        )
        # The encoder is assumed bidirectional, so context vectors have
        # dimension hidden_size * 2; the LSTM consumes [embedding; context].
        self.lstm = nn.LSTM(
            input_size=embed_size + hidden_size * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        # Output projection over [LSTM output; context; embedding].
        self.fc_out = nn.Linear(
            hidden_size + hidden_size * 2 + embed_size,
            vocab_size
        )
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_size + hidden_size * 2 + embed_size)

    def forward(
        self,
        input_token: torch.Tensor,
        decoder_hidden: torch.Tensor,
        decoder_cell: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # input_token: (batch,) -> embedded: (batch, 1, embed_size)
        embedded = self.dropout(self.embedding(input_token.unsqueeze(1)))
        # Attend over encoder outputs using the top decoder layer's state.
        top_hidden = decoder_hidden[-1]
        context, attention_weights = self.attention(
            top_hidden, encoder_outputs, mask
        )
        context = self.dropout(context)
        # Run one LSTM step on the concatenated [embedding; context] input.
        lstm_input = torch.cat((embedded, context.unsqueeze(1)), dim=2)
        output, (decoder_hidden, decoder_cell) = self.lstm(
            lstm_input,
            (decoder_hidden, decoder_cell)
        )
        output = output.squeeze(1)
        embedded = embedded.squeeze(1)
        # Normalize [output; context; embedding] and project to vocab logits.
        output_context = torch.cat((output, context, embedded), dim=1)
        output_context = self.layer_norm(output_context)
        prediction = self.fc_out(output_context)
        return prediction, decoder_hidden, decoder_cell, attention_weights
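
# ---------------------------------------------------------------------------
# The sketch below is not part of the original file. It shows one attention
# module that satisfies the interface this Decoder assumes (callable as
# attention(top_hidden, encoder_outputs, mask), returning a context vector of
# size hidden_size * 2 plus per-position weights), together with a single
# decoding step. All names and sizes here are illustrative assumptions, not
# the project's actual encoder or attention implementation.
class AdditiveAttention(nn.Module):
    """Minimal Bahdanau-style attention over bidirectional encoder outputs."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # top_hidden:      (batch, hidden_size)
        # encoder_outputs: (batch, src_len, hidden_size * 2)
        self.attn = nn.Linear(hidden_size * 3, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(
        self,
        top_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        src_len = encoder_outputs.size(1)
        # Repeat the decoder state across source positions and score each one.
        hidden = top_hidden.unsqueeze(1).expand(-1, src_len, -1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        scores = self.v(energy).squeeze(2)                  # (batch, src_len)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1)              # (batch, src_len)
        # Weighted sum of encoder outputs -> (batch, hidden_size * 2)
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, weights


if __name__ == "__main__":
    # Smoke test with illustrative sizes: batch of 4, source length 7.
    vocab_size, embed_size, hidden_size, num_layers = 100, 32, 64, 2
    decoder = Decoder(vocab_size, embed_size, hidden_size,
                      AdditiveAttention(hidden_size), num_layers)
    input_token = torch.randint(1, vocab_size, (4,))        # (batch,)
    h = torch.zeros(num_layers, 4, hidden_size)             # LSTM hidden state
    c = torch.zeros(num_layers, 4, hidden_size)             # LSTM cell state
    encoder_outputs = torch.randn(4, 7, hidden_size * 2)    # bidirectional encoder
    mask = torch.ones(4, 7, dtype=torch.bool)               # no padding positions
    logits, h, c, weights = decoder(input_token, h, c, encoder_outputs, mask)
    print(logits.shape, weights.shape)  # torch.Size([4, 100]) torch.Size([4, 7])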