import torch
import torch.nn as nn
from typing import Tuple


class Encoder(nn.Module):
    """Bidirectional LSTM encoder over embedded token sequences."""

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        hidden_size: int,
        num_layers: int = 2,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Token embedding; index 0 is reserved for padding and its vector stays zeroed.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=0,
        )

        # Bidirectional, multi-layer LSTM. Inter-layer dropout only applies
        # when num_layers > 1 (PyTorch warns if set on a single layer).
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)
        # Outputs concatenate forward and backward states, hence hidden_size * 2.
        self.layer_norm = nn.LayerNorm(hidden_size * 2)

    def forward(
        self,
        input_seq: torch.Tensor,
        input_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # (batch, seq_len) -> (batch, seq_len, embed_size)
        embedded = self.dropout(self.embedding(input_seq))

        # Pack so the LSTM skips padded positions; lengths must live on the CPU.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded,
            input_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        packed_output, (hidden, cell) = self.lstm(packed_embedded)

        # Unpack back to a padded tensor: (batch, seq_len, hidden_size * 2).
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output,
            batch_first=True,
        )
        outputs = self.layer_norm(outputs)

        # hidden/cell shapes: (num_layers * 2, batch, hidden_size).
        return outputs, hidden, cell
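

# Minimal usage sketch (illustrative only; the dimensions, batch, and lengths
# below are hypothetical, not from the original source). It demonstrates the
# input contract (padded token IDs plus true lengths) and the output shapes.
if __name__ == "__main__":
    encoder = Encoder(vocab_size=10_000, embed_size=256, hidden_size=512)

    lengths = torch.tensor([20, 17, 12, 9])    # true (unpadded) sequence lengths
    batch = torch.randint(1, 10_000, (4, 20))  # 4 sequences, padded to length 20
    for i, length in enumerate(lengths):
        batch[i, length:] = 0                  # fill pad positions with padding_idx

    outputs, hidden, cell = encoder(batch, lengths)
    print(outputs.shape)  # torch.Size([4, 20, 1024]) -> hidden_size * 2 (bidirectional)
    print(hidden.shape)   # torch.Size([4, 4, 512])   -> num_layers * 2 directions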