| """ |
| Model implementation for VQA |
| """ |
| import os |
| import json |
| import torch |
| import torch.nn as nn |
| from transformers import AutoTokenizer, AutoModel, AutoConfig, ViTImageProcessor, ViTModel |
|
|
class VQAModel(nn.Module):
    """Vision-language model for Visual Question Answering.

    A pretrained vision encoder and a pretrained text encoder each produce a
    sequence of hidden states. The hidden state at position 0 of each sequence
    is linearly projected to a shared width, the two projections are
    concatenated and fused by a small MLP, and two heads read the fused
    vector: one scores ``num_answers`` candidate answers, the other emits two
    logits (presumably answerable vs. unanswerable -- confirm against the
    training labels).
    """

    def __init__(self, config, num_answers: int):
        """Build the encoders, projections, fusion block and both heads.

        Args:
            config: Mapping read for the keys ``'vision_model'`` and
                ``'text_model'`` (both passed to ``from_pretrained``),
                ``'hidden_size'`` (shared feature width) and ``'dropout'``
                (dropout probability used in every MLP).
            num_answers: Number of output logits of the answer classifier.

        Raises:
            KeyError: If ``config`` lacks one of the keys listed above.
        """
        super().__init__()
        self.config = config
        self.num_answers = num_answers

        hidden_size = config['hidden_size']
        dropout = config['dropout']

        # The config attached to a loaded encoder is reused instead of being
        # fetched a second time; this also guarantees the projection input
        # width matches the encoder that was actually instantiated.
        self.vision_encoder = ViTModel.from_pretrained(config['vision_model'])
        self.vision_config = self.vision_encoder.config

        self.text_encoder = AutoModel.from_pretrained(config['text_model'])
        self.text_config = self.text_encoder.config

        # Bring both modalities to the same width before fusing them.
        self.vision_projection = nn.Linear(
            self.vision_config.hidden_size, hidden_size
        )
        self.text_projection = nn.Linear(
            self.text_config.hidden_size, hidden_size
        )

        # NOTE: attribute names and the layer order inside every
        # ``nn.Sequential`` determine the ``state_dict`` keys; keep them
        # stable so previously saved checkpoints still load.
        self.fusion = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_answers),
        )

        self.answerable_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 2),
        )

    @staticmethod
    def _encode(encoder, projection, encodings):
        """Run ``encoder``, take the hidden state at position 0 and project it.

        ``encodings`` is unpacked as keyword arguments for ``encoder``, whose
        output must expose ``last_hidden_state``.
        """
        first_position = encoder(**encodings).last_hidden_state[:, 0]
        return projection(first_position)

    def forward(self, image_encodings, question_encodings):
        """Compute answer and answerability logits for a batch.

        Args:
            image_encodings: Mapping unpacked as keyword arguments into the
                vision encoder (presumably an image-processor output such as
                ``pixel_values`` -- confirm against the caller).
            question_encodings: Mapping unpacked as keyword arguments into
                the text encoder (presumably a tokenizer output -- confirm
                against the caller).

        Returns:
            A dict with three entries: ``'answer_logits'`` (output of the
            answer classifier), ``'answerable_logits'`` (output of the
            two-logit head) and ``'fused_features'`` (the fused multimodal
            vector both heads consume).
        """
        vision_embeds = self._encode(
            self.vision_encoder, self.vision_projection, image_encodings
        )
        text_embeds = self._encode(
            self.text_encoder, self.text_projection, question_encodings
        )

        # Concatenate along the feature axis, then mix the two modalities.
        fused_features = self.fusion(
            torch.cat([vision_embeds, text_embeds], dim=1)
        )

        return {
            'answer_logits': self.classifier(fused_features),
            'answerable_logits': self.answerable_classifier(fused_features),
            'fused_features': fused_features,
        }