import torch
import torch.nn as nn


class CharRNN(nn.Module):
    """Character-level RNN language model with optional dense projection."""

    def __init__(
        self,
        emb_in: int,
        emb_dim: int,
        hidden_dim: int = 128,
        gru_layers: int = 2,
        dropout: float = 0.1,
        dense_layer: bool = False,
        dense_dropout: float = 0.1,
    ):
        """
        Args:
            emb_in (int): Vocabulary size.
            emb_dim (int): Embedding dimension.
            hidden_dim (int): Hidden state dimension of the GRU.
            gru_layers (int): Number of stacked GRU layers.
            dropout (float): Dropout between GRU layers.
            dense_layer (bool): Whether to apply an extra dense projection.
            dense_dropout (float): Dropout rate for the dense layer.
        """
        super().__init__()
        self.embedding = nn.Embedding(emb_in, emb_dim)
        self.gru = nn.GRU(
            emb_dim, hidden_dim, gru_layers, batch_first=True, dropout=dropout
        )
        self.dense_layer = dense_layer
        # Only build the optional projection when it is used, so a disabled
        # run does not register (and checkpoint) unused parameters.
        if dense_layer:
            self.dense = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dense_dropout),
            )
        self.output = nn.Linear(hidden_dim, emb_in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len).

        Returns:
            Tensor: Logits of shape (batch_size, emb_in, seq_len).
        """
        x = self.embedding(x)    # (batch_size, seq_len, emb_dim)
        x, _ = self.gru(x)       # (batch_size, seq_len, hidden_dim)
        if self.dense_layer:
            x = self.dense(x)    # (batch_size, seq_len, hidden_dim)
        logits = self.output(x)  # (batch_size, seq_len, emb_in)
        # Move the class dimension to position 1, the layout that
        # nn.CrossEntropyLoss expects for sequence targets.
        return logits.permute(0, 2, 1)  # (batch_size, emb_in, seq_len)
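

# A minimal usage sketch (the vocabulary size, batch shape, and random
# targets below are hypothetical, not part of the model definition):
# it shows that the permuted logits line up directly with
# nn.CrossEntropyLoss, which expects class scores in dimension 1.
if __name__ == "__main__":
    vocab_size = 96  # assumed printable-ASCII-sized character vocabulary
    model = CharRNN(emb_in=vocab_size, emb_dim=32, dense_layer=True)

    tokens = torch.randint(0, vocab_size, (4, 50))   # (batch_size, seq_len)
    logits = model(tokens)                           # (batch_size, emb_in, seq_len)
    assert logits.shape == (4, vocab_size, 50)

    # Placeholder next-character targets; in training these would be the
    # input sequence shifted by one position.
    targets = torch.randint(0, vocab_size, (4, 50))  # (batch_size, seq_len)
    loss = nn.CrossEntropyLoss()(logits, targets)
    print(f"logits {tuple(logits.shape)}, loss {loss.item():.3f}")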