import torch
import torch.nn as nn
import torchvision.models as models


class VQAModel(nn.Module):
    """Visual question answering classifier: ResNet-18 image encoder + LSTM question encoder.

    The image feature vector and the final LSTM hidden state are concatenated
    and passed through a two-layer MLP that scores each candidate answer.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_answers: int,
    ) -> None:
        """Build the image encoder, question encoder and answer classifier.

        Args:
            vocab_size: Number of rows in the question token embedding table.
            embed_dim: Size of each token embedding (LSTM input size).
            hidden_dim: Size of the LSTM hidden state.
            num_answers: Number of answer classes scored by the output layer.
        """
        super().__init__()

        # ResNet-18 with torchvision's default pretrained weights
        # (NOTE: presumably fetched/cached by torchvision on first use).
        self.cnn = models.resnet18(weights="DEFAULT")
        # Read the feature width from the backbone before dropping its head,
        # rather than hard-coding it (512 for ResNet-18).
        img_feat_dim = self.cnn.fc.in_features
        # Replace the classification head so the CNN returns pooled features.
        self.cnn.fc = nn.Identity()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: the LSTM consumes (batch, seq_len, embed_dim).
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

        self.fc1 = nn.Linear(img_feat_dim + hidden_dim, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_answers)

    def forward(self, image: torch.Tensor, question: torch.Tensor) -> torch.Tensor:
        """Score every answer class for a batch of (image, question) pairs.

        Args:
            image: Batch of images fed straight to the ResNet backbone
                (assumes the usual (batch, 3, H, W) float layout - TODO confirm
                against the data pipeline).
            question: Integer token indices of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, num_answers) with unnormalized answer
            scores (no softmax is applied here).
        """
        img_feat = self.cnn(image)

        q_embed = self.embedding(question)
        _, (h, _) = self.lstm(q_embed)
        # h has shape (num_layers, batch, hidden_dim) regardless of
        # batch_first; take the last layer's state. For this single-layer
        # LSTM that equals h.squeeze(0), but it stays 2-D if layers are added.
        q_feat = h[-1]

        fused = torch.cat((img_feat, q_feat), dim=1)
        x = self.relu(self.fc1(fused))
        return self.fc2(x)