File size: 409 Bytes
c0741ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn

class MiniText(nn.Module):
    """Minimal character-level sequence model: embedding -> 1-layer GRU -> vocab logits.

    Defaults reproduce the original fixed architecture (vocab 256, width 16),
    so ``MiniText()`` is backward-compatible with existing callers.

    Args:
        vocab_size: number of distinct token ids (embedding rows / output logits).
        embed_dim: embedding vector width fed into the GRU.
        hidden_dim: GRU hidden-state width.
    """

    def __init__(self, vocab_size: int = 256, embed_dim: int = 16, hidden_dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: inputs/outputs are (batch, seq, features).
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h=None):
        """Run one forward pass.

        Args:
            x: LongTensor of token ids, shape (batch, seq).
            h: optional initial GRU hidden state, shape (1, batch, hidden_dim);
               ``None`` lets the GRU start from zeros.

        Returns:
            (logits, h): logits of shape (batch, seq, vocab_size) and the
            final hidden state of shape (1, batch, hidden_dim), suitable for
            feeding back in on the next call (stateful generation).
        """
        x = self.embed(x)
        out, h = self.gru(x, h)
        logits = self.fc(out)
        return logits, h