import torch
import torch.nn as nn


class CharRNN(nn.Module):
    """Character-level RNN language model with optional dense projection."""

    def __init__(
        self,
        emb_in: int,
        emb_dim: int,
        hidden_dim: int = 128,
        gru_layers: int = 2,
        dropout: float = 0.1,
        dense_layer: bool = False,
        dense_dropout: float = 0.1,
    ):
"""
Args:
emb_in (int): Vocabulary size.
emb_dim (int): Embedding dimension.
hidden_dim (int): Hidden state dimension of GRU.
gru_layers (int): Number of GRU layers stacked.
dropout (float): Dropout between GRU layers.
dense_layer (bool): Whether to apply an extra dense projection.
dense_dropout (float): Dropout rate for dense layer.
"""
        super().__init__()
        self.Embedding = nn.Embedding(emb_in, emb_dim)
        # nn.GRU only applies dropout between stacked layers; zero it out
        # for a single layer to avoid the PyTorch warning.
        self.GRU = nn.GRU(
            emb_dim, hidden_dim, gru_layers, batch_first=True,
            dropout=dropout if gru_layers > 1 else 0.0,
        )
        self.dense_layer = dense_layer
        # Build the projection only when requested, so an unused dense block
        # does not add parameters to the state dict and optimizer.
        self.dense = (
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dense_dropout),
            )
            if dense_layer
            else None
        )
        self.output = nn.Linear(hidden_dim, emb_in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len).

        Returns:
            Tensor: Logits of shape (batch_size, emb_in, seq_len).
        """
        x = self.Embedding(x)  # (batch_size, seq_len, emb_dim)
        x, _ = self.GRU(x)  # (batch_size, seq_len, hidden_dim)
        if self.dense_layer:
            x = self.dense(x)  # (batch_size, seq_len, hidden_dim)
        logits = self.output(x)  # (batch_size, seq_len, emb_in)
        # Channels-first layout, as expected by nn.CrossEntropyLoss for
        # sequence targets of shape (batch_size, seq_len).
        return logits.permute(0, 2, 1)  # (batch_size, emb_in, seq_len)
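

if __name__ == "__main__":
    # Minimal usage sketch under assumed hyperparameters (a 65-character
    # vocabulary, batch of 4, sequence length 32); these values are
    # illustrative, not taken from any particular training setup.
    model = CharRNN(emb_in=65, emb_dim=32, dense_layer=True)
    tokens = torch.randint(0, 65, (4, 32))  # (batch_size, seq_len)
    logits = model(tokens)
    assert logits.shape == (4, 65, 32)  # (batch_size, emb_in, seq_len)
    # The channels-first output feeds nn.CrossEntropyLoss directly:
    targets = torch.randint(0, 65, (4, 32))
    loss = nn.CrossEntropyLoss()(logits, targets)
    print(f"logits {tuple(logits.shape)}, loss {loss.item():.4f}")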