import torch.nn as nn


class DecoderRNN(nn.Module):
    """LSTM caption decoder.

    Embeds caption token ids, runs them through an LSTM whose initial
    hidden/cell states are projected from the image feature vector, and
    maps each step's hidden state to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1,
                 padding_idx=0):
        """
        Args:
            vocab_size: size of the output vocabulary.
            embed_size: dimensionality of token embeddings; must match the
                feature vector size fed to ``forward`` (the init_h/init_c
                projections take ``embed_size`` inputs).
            hidden_size: LSTM hidden state size.
            num_layers: number of stacked LSTM layers.
            padding_idx: embedding index whose vector stays zero (PAD token).
        """
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size,
                                  embedding_dim=embed_size,
                                  padding_idx=padding_idx)
        self.lstm = nn.LSTM(input_size=embed_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Project image features into the LSTM's initial hidden/cell states.
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def forward(self, features, captions, hidden=None):
        """Decode one batch of caption tokens.

        Args:
            features: (B, embed_size) image feature vectors; used to build
                the initial LSTM state when ``hidden`` is not supplied.
            captions: (B, seqlen) token ids (seqlen == 1 at inference).
            hidden: optional (h, c) state from a previous step; when None,
                it is derived from ``features``.

        Returns:
            outputs: (B, seqlen, vocab_size) logits per time step.
            hidden: (h, c) final LSTM state, each (num_layers, B, hidden_size).
        """
        # Fix: identity check with `is`, not `==` (the original compared
        # a tuple of tensors to None with `==`).
        if hidden is None:
            # (B, hidden_size) -> (1, B, hidden_size) -> (num_layers, B, hidden_size)
            h0 = self.init_h(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
            c0 = self.init_c(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
            hidden = (h0, c0)
        # (B, seqlen) -> (B, seqlen, embed_size)
        embeddings = self.embed(captions)
        # -> (B, seqlen, hidden_size), plus updated (h, c)
        outputs, hidden = self.lstm(embeddings, hidden)
        # -> (B, seqlen, vocab_size)
        outputs = self.linear(outputs)
        return outputs, hidden