File size: 837 Bytes
1e5f3d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
import torchvision.models as models

class VQAModel(nn.Module):
    """Visual question answering classifier: CNN image features + LSTM question encoding.

    The image is encoded by a ResNet-18 backbone whose classification head is
    replaced with ``nn.Identity``. The question is encoded by an embedding
    layer followed by a single-layer LSTM. The two feature vectors are
    concatenated and passed through a two-layer MLP that produces one score
    per candidate answer.

    Args:
        vocab_size: Number of distinct token ids the question embedding accepts.
        embed_dim: Size of each token embedding.
        hidden_dim: LSTM hidden size (also the size of the question feature).
        num_answers: Number of output classes (candidate answers).
        fusion_dim: Width of the hidden layer in the fusion MLP. Defaults to
            256, the value that was previously hard-coded.
        cnn_weights: Forwarded as ``weights=`` to
            ``torchvision.models.resnet18``. ``"DEFAULT"`` (the original
            behavior) requests the pretrained weights, which torchvision may
            download on first use; ``None`` gives a randomly initialised
            backbone.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        hidden_dim,
        num_answers,
        fusion_dim=256,
        cnn_weights="DEFAULT",
    ):
        super().__init__()

        self.cnn = models.resnet18(weights=cnn_weights)
        # Read the backbone's feature width from its original classifier
        # instead of hard-coding 512, then drop the classifier so the backbone
        # returns its pre-classifier features.
        img_dim = self.cnn.fc.in_features
        self.cnn.fc = nn.Identity()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

        self.fc1 = nn.Linear(img_dim + hidden_dim, fusion_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(fusion_dim, num_answers)

    def forward(self, image, question):
        """Score every candidate answer for each (image, question) pair.

        Args:
            image: Batch of images fed to the ResNet-18 backbone; presumably
                ``(batch, 3, H, W)`` float tensors normalised the way the
                chosen weights expect — TODO confirm against the data loader.
            question: ``(batch, seq_len)`` tensor of integer token ids
                (``batch_first=True`` puts the batch axis first).

        Returns:
            ``(batch, num_answers)`` tensor of unnormalised logits; no softmax
            is applied here.
        """
        img_feat = self.cnn(image)

        q_embed = self.embedding(question)
        _, (h, _) = self.lstm(q_embed)
        # h is (num_layers * num_directions, batch, hidden_dim). Take the last
        # layer's final state: identical to h.squeeze(0) for the current
        # single-layer LSTM, but still (batch, hidden_dim) if layers are added.
        # NOTE(review): if questions are right-padded, the final state has also
        # consumed the padding tokens — consider packing by true length.
        q_feat = h[-1]

        x = self.relu(self.fc1(torch.cat((img_feat, q_feat), dim=1)))
        return self.fc2(x)