from typing import Optional, Tuple

import torch
import torch.nn as nn


class MiniText(nn.Module):
    """Tiny recurrent text model: embedding -> single-layer GRU -> linear head.

    Each input token id is embedded, passed through a GRU, and projected to one
    logit per vocabulary entry at every time step.

    Args:
        vocab_size: Number of distinct token ids. Inputs must lie in
            ``[0, vocab_size)`` and the output has ``vocab_size`` logits per
            step. Defaults to 256 (presumably byte-level tokens — confirm
            against the tokenizer).
        embed_dim: Width of each token embedding. Defaults to 16.
        hidden_size: Width of the GRU hidden state. Defaults to 16.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        embed_dim: int = 16,
        hidden_size: int = 16,
    ) -> None:
        super().__init__()
        # Attribute names and construction order are kept as before so that
        # existing state_dict checkpoints still load and default-argument
        # initialization consumes the RNG in the same order.
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(
        self,
        x: torch.Tensor,
        h: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model over a sequence of token ids.

        Args:
            x: Integer token ids. Typically shaped ``(batch, seq_len)``
                because the GRU is built with ``batch_first=True``.
            h: Optional initial GRU hidden state, shaped
                ``(1, batch, hidden_size)``. ``None`` lets the GRU start from
                zeros. Pass back the returned state to continue a sequence.

        Returns:
            A pair ``(logits, h)``:
            - ``logits``: shape ``(batch, seq_len, vocab_size)``.
            - ``h``: the final GRU hidden state, shape
              ``(1, batch, hidden_size)``.
        """
        emb = self.embed(x)
        out, h = self.gru(emb, h)
        logits = self.fc(out)
        return logits, h