# caption-gen / decoder.py
# Sher1988's picture
# Change structure of the project.
# eb55711
import torch.nn as nn
class DecoderRNN(nn.Module):
    """LSTM caption decoder.

    Embeds token ids, initializes the LSTM state from an image feature
    vector, and projects hidden states to vocabulary logits.

    Args:
        vocab_size: size of the output vocabulary (and embedding table).
        embed_size: dimensionality of token embeddings; must also match
            the feature-vector size fed to ``forward`` (the init_h/init_c
            projections take ``embed_size`` inputs).
        hidden_size: LSTM hidden-state dimensionality.
        num_layers: number of stacked LSTM layers (default 1).
        padding_idx: embedding index whose vector stays zero (default 0).
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, padding_idx=0):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=padding_idx)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Learned projections from the image feature vector to the initial
        # hidden and cell states of every LSTM layer.
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def forward(self, features, captions, hidden=None):
        """Run one decoding pass.

        Args:
            features: (B, embed_size) image feature vectors; used only to
                build the initial LSTM state when ``hidden`` is None.
            captions: (B, seqlen) token ids — full sequence during training,
                a single step (B, 1) during autoregressive inference.
            hidden: optional (h, c) state from a previous step; when None,
                the state is initialized from ``features``.

        Returns:
            outputs: (B, seqlen, vocab_size) unnormalized logits.
            hidden: the final (h, c) LSTM state, reusable for the next step.
        """
        # Use identity comparison for None (PEP 8); `== None` would invoke
        # __eq__, which is elementwise (not boolean) on tensors/tuples.
        if hidden is None:
            # features: (B, embed_size) -> (B, hidden_size) -> (1, B, hidden_size)
            # -> (num_layers, B, hidden_size), the shape nn.LSTM expects.
            h0 = self.init_h(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
            c0 = self.init_c(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
            hidden = (h0, c0)
        embeddings = self.embed(captions)  # (B, seqlen) -> (B, seqlen, embed_size); (B, 1, embed_size) at inference
        outputs, hidden = self.lstm(embeddings, hidden)  # (B, seqlen, hidden_size)
        outputs = self.linear(outputs)  # (B, seqlen, vocab_size)
        return outputs, hidden