File size: 409 Bytes
c0741ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
import torch.nn as nn
class MiniText(nn.Module):
    """Small embedding -> GRU -> linear model over integer token ids.

    The layer sizes are configurable; the defaults reproduce the original
    fixed architecture (256-entry vocabulary, 16-wide embedding and hidden
    state), so ``MiniText()`` builds the same model with the same parameter
    names (``embed``, ``gru``, ``fc``) as before.

    Args:
        vocab_size: Number of distinct token ids. Used both as the size of
            the embedding table and as the width of the output logits.
            NOTE(review): the default of 256 presumably corresponds to raw
            byte values -- confirm against the caller's tokenization.
        embed_dim: Width of each token embedding (the GRU input size).
        hidden_dim: Width of the GRU hidden state (the input size of the
            output projection).
    """

    def __init__(
        self,
        vocab_size: int = 256,
        embed_dim: int = 16,
        hidden_dim: int = 16,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: the GRU consumes and produces tensors laid out as
        # (batch, seq, feature).
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        # Project every GRU output step back onto the vocabulary.
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h=None):
        """Compute per-position logits and the updated GRU hidden state.

        Args:
            x: Integer tensor of token ids in ``[0, vocab_size)``. With the
                batch-first GRU this is laid out as ``(batch, seq)``.
            h: Optional initial hidden state forwarded unchanged to the GRU.
                ``None`` lets the GRU start from its default (zero) state.
                Pass the ``h`` returned by a previous call to continue a
                sequence across calls.

        Returns:
            A ``(logits, h)`` tuple: ``logits`` holds one ``vocab_size``-wide
            score vector per input position, and ``h`` is the hidden state
            returned by the GRU after the last position.
        """
        embedded = self.embed(x)
        out, h = self.gru(embedded, h)
        logits = self.fc(out)
        return logits, h
|