# Spaces: Runtime error
# File size: 1,380 Bytes
# Revisions: eb55711 38e36bb
import torch.nn as nn
class DecoderRNN(nn.Module):
    """LSTM decoder that turns token ids into per-step vocabulary logits.

    The initial LSTM state is derived from a feature vector through two
    linear projections (one for the hidden state, one for the cell state)
    unless the caller supplies an explicit state.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, padding_idx=0):
        """Build the embedding, LSTM, output projection and state initialisers.

        Args:
            vocab_size: Number of tokens; size of the embedding table and of
                the output logits.
            embed_size: Embedding dimension. Also the expected last dimension
                of ``features`` in ``forward``, because ``init_h``/``init_c``
                project from ``embed_size``.
            hidden_size: LSTM hidden/cell state dimension.
            num_layers: Number of stacked LSTM layers.
            padding_idx: Token id passed to ``nn.Embedding`` as the padding
                index.
        """
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=padding_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Projections from the feature vector to the initial LSTM state.
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def forward(self, features, captions, hidden=None):
        """Decode ``captions`` into vocabulary logits.

        Args:
            features: Tensor of shape (B, embed_size). Only read when
                ``hidden`` is None, to build the initial LSTM state.
            captions: Integer token ids of shape (B, seqlen); a single-step
                call uses seqlen == 1.
            hidden: Optional ``(h, c)`` tuple, each of shape
                (num_layers, B, hidden_size). When given, it is used as the
                LSTM state and ``features`` is ignored.

        Returns:
            A tuple ``(outputs, hidden)`` where ``outputs`` has shape
            (B, seqlen, vocab_size) and ``hidden`` is the LSTM's final
            ``(h, c)`` state.
        """
        # Identity check: `== None` is an equality test and is the wrong tool
        # for detecting "no state supplied".
        if hidden is None:
            # Dataflow: (B, embed_size) -> (B, hidden_size)
            #   -> (1, B, hidden_size) -> (num_layers, B, hidden_size)
            # Every LSTM layer starts from the same projected state.
            num_layers = self.lstm.num_layers
            h0 = self.init_h(features).unsqueeze(0).repeat(num_layers, 1, 1)
            c0 = self.init_c(features).unsqueeze(0).repeat(num_layers, 1, 1)
            hidden = (h0, c0)

        embeddings = self.embed(captions)  # (B, seqlen) -> (B, seqlen, embed_size)
        outputs, hidden = self.lstm(embeddings, hidden)  # (B, seqlen, hidden_size)
        outputs = self.linear(outputs)  # (B, seqlen, vocab_size)
        return outputs, hidden