| import torch |
| import torch.nn as nn |
|
|
class CharRNN(nn.Module):
    """Character-level RNN language model with optional dense projection.

    Pipeline: token ids -> embedding -> stacked GRU -> (optional dense
    block) -> linear projection to vocabulary logits.

    Note: attribute names ``Embedding`` and ``GRU`` are kept capitalized
    (against PEP 8) to preserve existing ``state_dict`` keys.
    """

    def __init__(
        self,
        emb_in: int,
        emb_dim: int,
        hidden_dim: int = 128,
        gru_layers: int = 2,
        dropout: float = 0.1,
        dense_layer: bool = False,
        dense_dropout: float = 0.1,
    ):
        """
        Args:
            emb_in (int): Vocabulary size.
            emb_dim (int): Embedding dimension.
            hidden_dim (int): Hidden state dimension of GRU.
            gru_layers (int): Number of GRU layers stacked.
            dropout (float): Dropout between GRU layers (only applied
                when ``gru_layers > 1``).
            dense_layer (bool): Whether to apply an extra dense projection.
            dense_dropout (float): Dropout rate for dense layer.
        """
        super().__init__()
        self.Embedding = nn.Embedding(emb_in, emb_dim)
        # nn.GRU warns if dropout > 0 with a single layer (inter-layer
        # dropout is a no-op there), so pass 0.0 in that case.
        self.GRU = nn.GRU(
            emb_dim,
            hidden_dim,
            gru_layers,
            batch_first=True,
            dropout=dropout if gru_layers > 1 else 0.0,
        )
        self.dense_layer = dense_layer
        # Only build the dense block when it is actually used; otherwise
        # its parameters would sit unused in the state_dict and waste
        # memory / optimizer state.
        if dense_layer:
            self.dense = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dense_dropout),
            )
        else:
            self.dense = None
        self.output = nn.Linear(hidden_dim, emb_in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of token ids, shape (batch_size, seq_len).

        Returns:
            Tensor: Logits of shape (batch_size, emb_in, seq_len) — the
            vocabulary axis is moved to dim 1 (the layout expected by
            ``nn.CrossEntropyLoss`` for sequence targets).
        """
        x = self.Embedding(x)
        x, _ = self.GRU(x)
        if self.dense_layer:
            x = self.dense(x)
        logits = self.output(x)
        return logits.permute(0, 2, 1)
|
|