"""CNN-encoder / RNN-decoder image-captioning model (ResNet-50 + LSTM)."""

import torch
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):
    """Pretrained ResNet-50 that maps an image to a fixed-size feature vector."""

    def __init__(self, embed_size):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Freeze the pretrained backbone; only the new embedding layer is trained.
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Drop the final classification layer, keeping everything up to the
        # global average pool.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)
        self.batch = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.embed.weight.data.normal_(0.0, 0.02)
        self.embed.bias.data.fill_(0)

    def forward(self, images):
        features = self.resnet(images)                  # (batch, 2048, 1, 1)
        features = features.view(features.size(0), -1)  # (batch, 2048)
        features = self.batch(self.embed(features))     # (batch, embed_size)
        return features
    
    
class DecoderRNN(nn.Module):
    """LSTM language model that generates a caption conditioned on image features."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # batch_first=False (the default): inputs are (seq_len, batch, embed_size).
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))
        # Prepend the image features as the first "token" of the sequence.
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)  # (seq_len + 1, batch, vocab_size)
        return outputs
    
class CNNtoRNN(nn.Module):
    """End-to-end captioning model: EncoderCNN features seed the DecoderRNN."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single image.

        Call model.eval() first: BatchNorm1d in the encoder cannot compute
        batch statistics for a single image in training mode.
        """
        result_caption = []
        with torch.no_grad():
            x = self.encoderCNN(image).unsqueeze(0)  # (1, 1, embed_size)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoderRNN.lstm(x, states)
                output = self.decoderRNN.linear(hiddens.squeeze(0))  # (1, vocab_size)
                predicted = output.argmax(1)
                result_caption.append(predicted.item())

                # Feed the predicted token's embedding back in as the next input.
                x = self.decoderRNN.embed(predicted).unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break

        return [vocabulary.itos[idx] for idx in result_caption]
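

# Minimal usage sketch: a shape smoke test with illustrative hyperparameters
# and a dummy vocabulary. The real vocab_size, itos mapping, and <EOS> token
# come from the project's dataset pipeline, which lives outside this file.
if __name__ == "__main__":
    embed_size, hidden_size, vocab_size, num_layers = 256, 256, 1000, 1

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers)
    model.eval()  # required before caption_image (BatchNorm1d with batch size 1)

    images = torch.randn(4, 3, 224, 224)              # (batch, channels, H, W)
    captions = torch.randint(0, vocab_size, (20, 4))  # (seq_len, batch)
    outputs = model(images, captions)
    print(outputs.shape)  # torch.Size([21, 4, 1000]); +1 for the image "token"

    class DummyVocab:
        # Stand-in for the project's vocabulary object; only .itos is used here.
        itos = {i: f"tok{i}" for i in range(vocab_size)}

    caption = model.caption_image(torch.randn(1, 3, 224, 224), DummyVocab())
    print(caption[:5])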