import torch
import torch.nn as nn

from transformers.modeling_outputs import BaseModelOutput
class ImageCaptionGenerationWithAttention(nn.Module):
    """Image-captioning model bridging a ViT encoder to a BART decoder.

    The ViT's patch embeddings are linearly projected into BART's
    ``d_model`` space and fed to BART as precomputed ``encoder_outputs``,
    so the BART decoder cross-attends to the visual features.
    """

    def __init__(self, vit_model, bart_model, tokenizer):
        """Store the submodules and build the ViT->BART projection.

        Args:
            vit_model: Vision encoder exposing ``config.hidden_size`` and
                returning patch embeddings (ModelOutput or tuple).
            bart_model: Seq2seq LM exposing ``config.d_model``; used for
                both training (``forward``) and decoding (``generate``).
            tokenizer: Tokenizer supplying ``bos_token_id``/``eos_token_id``
                for generation.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.vit = vit_model
        self.bart = bart_model
        # Bridge the dimensionality gap between the two pretrained models.
        self.visual_projection = nn.Linear(
            vit_model.config.hidden_size, bart_model.config.d_model)

    def _encode(self, pixel_values):
        """Run the ViT and project its hidden states to BART's d_model."""
        vit_outputs = self.vit(pixel_values)
        # Some configurations return a plain tuple instead of a ModelOutput;
        # element 0 is the last hidden state in that case.
        if isinstance(vit_outputs, tuple):
            last_hidden_state = vit_outputs[0]
        else:
            last_hidden_state = vit_outputs.last_hidden_state
        return self.visual_projection(last_hidden_state)

    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
        """Compute the BART seq2seq output (with loss) for a batch of images.

        Args:
            pixel_values: Batch of preprocessed images for the ViT.
            input_ids: Caption token ids; used as the supervision target
                when ``labels`` is not given.
            attention_mask: Accepted for API compatibility.
                NOTE(review): currently unused — SOURCE never forwarded it,
                and whether it describes encoder or decoder positions cannot
                be determined here; confirm against callers before wiring it.
            labels: Explicit target token ids. BUG FIX: the original always
                passed ``input_ids`` as labels, silently discarding a
                caller-supplied ``labels`` argument.

        Returns:
            The BART ModelOutput (includes ``loss``) when a target is
            available, otherwise the projected visual features.
        """
        visual_features = self._encode(pixel_values)
        target = labels if labels is not None else input_ids
        if target is None:
            # No supervision signal: expose the encoder side for inspection.
            return visual_features
        return self.bart(
            labels=target,
            encoder_outputs=BaseModelOutput(last_hidden_state=visual_features),
            return_dict=True,
        )

    def generate(self, pixel_values, max_length=50, num_beams=5, early_stopping=True):
        """Beam-search captions for a batch of images.

        Args:
            pixel_values: Batch of preprocessed images for the ViT.
            max_length: Maximum generated sequence length.
            num_beams: Beam width for beam search.
            early_stopping: Stop beams once ``num_beams`` candidates finish.

        Returns:
            Tensor of generated token ids (decode with the tokenizer).
        """
        self.eval()
        with torch.no_grad():
            visual_features = self._encode(pixel_values)
            generated_ids = self.bart.generate(
                encoder_outputs=BaseModelOutput(
                    last_hidden_state=visual_features),
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=early_stopping,
                decoder_start_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                return_dict_in_generate=False,
            )
        return generated_ids