File size: 1,380 Bytes
eb55711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38e36bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn as nn


class DecoderRNN(nn.Module):
    """LSTM decoder that maps token ids to per-step vocabulary logits.

    When no recurrent state is supplied, the initial hidden and cell states
    are produced from ``features`` by two separate linear projections and
    shared across all LSTM layers.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, padding_idx=0):
        """Build the embedding, LSTM, output projection and state initialisers.

        Args:
            vocab_size: Number of tokens; size of the embedding table and of
                the output logits.
            embed_size: Dimension of the token embeddings. ``features`` passed
                to ``forward`` must have this same last dimension, because the
                state initialisers are ``Linear(embed_size, hidden_size)``.
            hidden_size: Dimension of the LSTM hidden and cell states.
            num_layers: Number of stacked LSTM layers.
            padding_idx: Token id passed to ``nn.Embedding`` as ``padding_idx``.
        """
        super().__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=padding_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, vocab_size)

        # Projections from the feature vector to the initial LSTM state.
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def forward(self, features, captions, hidden=None):
        """Run the decoder over ``captions``.

        Args:
            features: Tensor of shape (B, embed_size). Only read when
                ``hidden`` is None; ignored otherwise.
            captions: Integer token ids of shape (B, seqlen). A single-step
                call uses seqlen == 1.
            hidden: Optional ``(h, c)`` tuple, each of shape
                (num_layers, B, hidden_size), e.g. the state returned by a
                previous call. When None, the state is derived from
                ``features``.

        Returns:
            A tuple ``(outputs, hidden)`` where ``outputs`` has shape
            (B, seqlen, vocab_size) and ``hidden`` is the final ``(h, c)``
            state returned by the LSTM.
        """
        # Identity check: `== None` is non-idiomatic and would dispatch to
        # elementwise comparison if a tensor were ever passed here.
        if hidden is None:
            num_layers = self.lstm.num_layers
            # (B, embed_size) -> (B, hidden_size) -> (1, B, hidden_size)
            #   -> (num_layers, B, hidden_size); every layer starts from the
            #   same projected state.
            h0 = self.init_h(features).unsqueeze(0).repeat(num_layers, 1, 1)
            c0 = self.init_c(features).unsqueeze(0).repeat(num_layers, 1, 1)
            hidden = (h0, c0)

        embeddings = self.embed(captions)  # (B, seqlen) -> (B, seqlen, embed_size)
        outputs, hidden = self.lstm(embeddings, hidden)  # (B, seqlen, hidden_size)
        outputs = self.linear(outputs)  # (B, seqlen, vocab_size)
        return outputs, hidden