```python
import torch
import torch.nn as nn
import torchvision.models as models


class VQAModel(nn.Module):
    """Baseline VQA model: a CNN image encoder and an LSTM question encoder,
    fused by concatenation and classified over a fixed answer set."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers):
        super().__init__()
        # Pretrained ResNet-18 as the image encoder; replacing the final
        # classification layer with Identity leaves a 512-dim feature vector.
        self.cnn = models.resnet18(weights="DEFAULT")
        self.cnn.fc = nn.Identity()
        # Question encoder: token embeddings fed through a single-layer LSTM.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # Fusion head: concatenated image + question features -> answer logits.
        self.fc1 = nn.Linear(512 + hidden_dim, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_answers)

    def forward(self, image, question):
        img_feat = self.cnn(image)          # (batch, 512)
        q_embed = self.embedding(question)  # (batch, seq_len, embed_dim)
        _, (h, _) = self.lstm(q_embed)      # h: (1, batch, hidden_dim)
        q_feat = h.squeeze(0)               # (batch, hidden_dim)
        x = self.relu(self.fc1(torch.cat((img_feat, q_feat), dim=1)))
        return self.fc2(x)                  # (batch, num_answers)
```
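To sanity-check the shapes flowing through the model, here is a minimal usage sketch. The hyperparameters (`vocab_size=10000`, `embed_dim=300`, `hidden_dim=512`, `num_answers=1000`, sequence length 20) are placeholder values chosen for illustration, not taken from the original; note that constructing the model downloads the pretrained ResNet-18 weights on first run.

```python
# Placeholder hyperparameters, assumed for this smoke test only.
model = VQAModel(vocab_size=10000, embed_dim=300, hidden_dim=512, num_answers=1000)
model.eval()

image = torch.randn(2, 3, 224, 224)          # batch of 2 RGB images at ResNet's usual input size
question = torch.randint(0, 10000, (2, 20))  # batch of 2 questions, 20 token IDs each

with torch.no_grad():
    logits = model(image, question)
print(logits.shape)  # torch.Size([2, 1000]): one score per candidate answer
```

With `hidden_dim=512`, the concatenated feature is 1024-dim, matching the `nn.Linear(512 + hidden_dim, 256)` fusion layer; changing `hidden_dim` adjusts that input size automatically.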