| import torch | |
| import torch.nn as nn | |
class MiniText(nn.Module):
    """Minimal sequence model: token embedding -> single-layer GRU -> logits.

    Operates over a 256-symbol vocabulary with a 16-dimensional embedding
    and hidden state, projecting each timestep back to 256 logits.
    """

    def __init__(self):
        super().__init__()
        # Vocabulary of 256 ids, embedded into 16 dimensions.
        self.embed = nn.Embedding(256, 16)
        # batch_first=True: inputs/outputs are (batch, seq, feature).
        self.gru = nn.GRU(16, 16, batch_first=True)
        # Per-step projection from hidden size back to vocabulary logits.
        self.fc = nn.Linear(16, 256)

    def forward(self, x, h=None):
        """Run one forward pass.

        Args:
            x: LongTensor of token ids, shape (batch, seq), values in [0, 256).
            h: optional initial hidden state, shape (1, batch, 16);
                defaults to zeros inside the GRU when None.

        Returns:
            A (logits, hidden) pair: logits of shape (batch, seq, 256) and
            the final hidden state, suitable for passing back in on the
            next call to continue the sequence.
        """
        embedded = self.embed(x)
        states, final_hidden = self.gru(embedded, h)
        return self.fc(states), final_hidden