Songyou's picture
add files
f3b11f9
raw
history blame
347 Bytes
import torch.nn as nn
import torch.nn.functional as F
class Generator(nn.Module):
    """Final generation step: linear projection followed by log-softmax.

    Projects feature vectors of size ``d_model`` to ``vocab`` scores with a
    learned ``nn.Linear`` layer. It then normalizes those scores with
    ``log_softmax`` over the last dimension, so the output holds
    log-probabilities rather than probabilities.

    Args:
        d_model: Size of the last dimension of the input tensor.
        vocab: Number of output classes (vocabulary size).
    """

    def __init__(self, d_model: int, vocab: int) -> None:
        super().__init__()
        # Keep the attribute name ``proj``: it determines the state_dict keys
        # ("proj.weight", "proj.bias"), so renaming it would break loading of
        # previously saved checkpoints.
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for ``x``.

        Args:
            x: Tensor whose last dimension has size ``d_model``; any leading
                dimensions are passed through unchanged by ``nn.Linear``.

        Returns:
            Tensor with the same leading shape as ``x`` and last dimension
            ``vocab``. Each slice along the last dimension exponentiates and
            sums to 1.
        """
        # The fused log_softmax is used instead of log(softmax(...)). PyTorch
        # documents the two-step form as slower and numerically unstable.
        return F.log_softmax(self.proj(x), dim=-1)