```python
from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
from transformers import VisionEncoderDecoderModel


class BirdCaptioningModel(nn.Module, PyTorchModelHubMixin):
    """Joint captioning + classification model built on a pretrained ViT-GPT2 captioner."""

    def __init__(self, num_classes=200):
        super().__init__()
        # Pretrained ViT encoder + GPT-2 decoder for image captioning
        self.base_model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        # Linear classification head on top of the decoder's hidden size
        # (200 classes by default, matching the bird dataset)
        self.hidden_size = self.base_model.decoder.config.hidden_size
        self.classifier = nn.Linear(self.hidden_size, num_classes)

    def forward(self, pixel_values, input_ids=None, attention_mask=None):
        # When no caption tokens are provided (e.g. at inference time),
        # seed the decoder with its start token for each image in the batch
        if input_ids is None:
            batch_size = pixel_values.shape[0]
            input_ids = torch.full(
                (batch_size, 1),
                self.base_model.config.decoder_start_token_id,
                dtype=torch.long,
                device=pixel_values.device,
            )
            attention_mask = torch.ones_like(input_ids)

        outputs = self.base_model(
            pixel_values=pixel_values,
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Classify from the last decoder layer's hidden state at the first position
        hidden_states = outputs.decoder_hidden_states[-1][:, 0, :]
        class_logits = self.classifier(hidden_states)

        # Return both the language-modeling logits (for captioning)
        # and the class logits (for species classification)
        return outputs.logits, class_logits
```
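
For orientation, here is a minimal usage sketch, not part of the original code: it assumes the `ViTImageProcessor` shipped with the same `nlpconnect/vit-gpt2-image-captioning` checkpoint for preprocessing, and the blank placeholder image is purely illustrative.

```python
from PIL import Image
import torch
from transformers import ViTImageProcessor

# Same checkpoint as the encoder-decoder, so preprocessing matches the ViT encoder
processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

model = BirdCaptioningModel(num_classes=200)
model.eval()

# Placeholder image for illustration only; use a real bird photo in practice
image = Image.new("RGB", (224, 224))
pixel_values = processor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    caption_logits, class_logits = model(pixel_values)

predicted_class = class_logits.argmax(dim=-1)  # index into the 200 bird classes
print(caption_logits.shape, predicted_class)
```

Because the class also mixes in `PyTorchModelHubMixin`, a trained instance can be saved and reloaded with `save_pretrained` / `from_pretrained`, or shared via `push_to_hub`.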