vqa_project / models /vqa_model.py
PRUTHVIn's picture
Upload folder using huggingface_hub
1e5f3d4 verified
raw
history blame contribute delete
837 Bytes
import torch
import torch.nn as nn
import torchvision.models as models
class VQAModel(nn.Module):
    """Answer classifier for visual question answering.

    The image is encoded by a ResNet-18 backbone whose classification head is
    replaced with ``nn.Identity``; the question is encoded by an embedding
    layer followed by a single-layer LSTM whose final hidden state is used as
    the question feature. The two feature vectors are concatenated and mapped
    to ``num_answers`` logits by a two-layer MLP (Linear -> ReLU -> Linear).

    Args:
        vocab_size: Number of rows in the question-token embedding table.
        embed_dim: Size of each token embedding.
        hidden_dim: LSTM hidden size, i.e. the width of the question feature.
        num_answers: Number of output logits (answer classes).
        pretrained: If True (default), build ResNet-18 with torchvision's
            ``weights="DEFAULT"`` (which may download them); if False, the
            backbone is randomly initialised, e.g. when a full checkpoint will
            be loaded afterwards and no network access is wanted.
        padding_idx: Optional padding token id, forwarded to ``nn.Embedding``
            (that row is zero-initialised and receives no gradient).
        fusion_dim: Width of the hidden layer of the fusion MLP.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_answers, *,
                 pretrained=True, padding_idx=None, fusion_dim=256):
        super().__init__()
        self.cnn = models.resnet18(weights="DEFAULT" if pretrained else None)
        # Take the feature width from the backbone itself (512 for ResNet-18)
        # rather than hard-coding it, so fc1 always matches the encoder output.
        image_feat_dim = self.cnn.fc.in_features
        self.cnn.fc = nn.Identity()
        self.embedding = nn.Embedding(vocab_size, embed_dim,
                                      padding_idx=padding_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc1 = nn.Linear(image_feat_dim + hidden_dim, fusion_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(fusion_dim, num_answers)

    def forward(self, image, question, question_lengths=None):
        """Return answer logits for a batch of (image, question) pairs.

        Args:
            image: Image batch fed straight into ResNet-18 — presumably
                ``(B, 3, H, W)`` normalised the way the chosen weights expect;
                TODO confirm against the data pipeline.
            question: ``(B, T)`` integer token ids (the LSTM is
                ``batch_first``).
            question_lengths: Optional number of real tokens per question
                (length ``B``, every entry >= 1). When given, padding is
                assumed to sit at the end of each row and is skipped by the
                LSTM; when omitted, all ``T`` positions are encoded (original
                behaviour).

        Returns:
            Logits of shape ``(B, num_answers)``.
        """
        img_feat = self.cnn(image)
        q_embed = self.embedding(question)
        if question_lengths is not None:
            # Packing makes the final hidden state come from each question's
            # last real token instead of from trailing padding.
            # pack_padded_sequence requires int64 lengths on the CPU.
            lengths = torch.as_tensor(question_lengths, dtype=torch.int64).cpu()
            q_embed = nn.utils.rnn.pack_padded_sequence(
                q_embed, lengths, batch_first=True, enforce_sorted=False
            )
        # h has shape (num_layers, B, hidden_dim); h[-1] is the top layer's
        # final state (identical to squeeze(0) for this single-layer LSTM).
        _, (h, _) = self.lstm(q_embed)
        q_feat = h[-1]
        x = self.relu(self.fc1(torch.cat((img_feat, q_feat), dim=1)))
        return self.fc2(x)