MiniText-v1.0-base / model.py
Arthur Samuel Galego Panucci FIgueiredo
Upload 5 files
c0741ab verified
raw
history blame contribute delete
409 Bytes
import torch
import torch.nn as nn
class MiniText(nn.Module):
    """Tiny sequence model: embedding -> single-layer GRU -> linear projection.

    Integer token ids are embedded, passed through one GRU layer, and each
    time step's GRU output is projected to one logit per vocabulary entry.

    The default sizes (256 ids, 16-dim embedding, 16-dim hidden state)
    reproduce the original fixed architecture, so ``MiniText()`` builds a
    model whose ``state_dict`` layout matches previously saved weights.
    NOTE(review): a vocabulary of 256 suggests byte-level tokens -- confirm
    against the tokenizer / training code.

    Args:
        vocab_size: Number of distinct token ids; sets both the embedding
            table size and the number of output logits.
        embed_dim: Size of each embedding vector (the GRU input size).
        hidden_dim: Size of the GRU hidden state (the projection input size).
    """

    def __init__(
        self,
        vocab_size: int = 256,
        embed_dim: int = 16,
        hidden_dim: int = 16,
    ) -> None:
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys),
        # so they are kept as ``embed`` / ``gru`` / ``fc``.
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: the GRU consumes and produces tensors laid out as
        # (batch, seq, feature) instead of PyTorch's default (seq, batch, ...).
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h=None):
        """Compute per-step logits and the updated hidden state.

        Args:
            x: Integer tensor of token ids, laid out as (batch, seq) for
                batched input. Ids must be valid indices into the embedding
                table.
            h: Optional initial GRU hidden state, (1, batch, hidden_dim) for
                batched input. ``None`` lets the GRU start from zeros. Pass
                the state returned by a previous call to continue a sequence.

        Returns:
            A ``(logits, h)`` tuple: ``logits`` holds unnormalised scores of
            shape (batch, seq, vocab_size); ``h`` is the GRU's final hidden
            state, suitable for feeding back in as ``h`` on the next call.
        """
        embedded = self.embed(x)
        out, h = self.gru(embedded, h)
        logits = self.fc(out)
        return logits, h