from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
from transformers import VisionEncoderDecoderModel


class BirdCaptioningModel(nn.Module, PyTorchModelHubMixin):
    """ViT-GPT2 captioning model with an added linear head for bird classification."""

    def __init__(self, num_classes=200):
        super().__init__()
        # Pretrained ViT encoder + GPT-2 decoder for image captioning.
        self.base_model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        # Classification head on top of the decoder's hidden states.
        self.hidden_size = self.base_model.decoder.config.hidden_size
        self.classifier = nn.Linear(self.hidden_size, num_classes)

    def forward(self, pixel_values, input_ids=None, attention_mask=None):
        # When no caption tokens are provided (e.g. at inference), seed the
        # decoder with its start token for every image in the batch.
        if input_ids is None:
            batch_size = pixel_values.shape[0]
            input_ids = torch.full(
                (batch_size, 1),
                self.base_model.config.decoder_start_token_id,
                dtype=torch.long,  # token ids must be integers
                device=pixel_values.device,
            )
            attention_mask = torch.ones_like(input_ids)

        outputs = self.base_model(
            pixel_values=pixel_values,
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Last decoder layer, first token: an image-conditioned representation
        # used as the feature for classification.
        hidden_states = outputs.decoder_hidden_states[-1][:, 0, :]
        class_logits = self.classifier(hidden_states)

        # Caption (language-model) logits and bird-class logits.
        return outputs.logits, class_logits
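
# Minimal usage sketch (not part of the original file; the image path is a
# hypothetical placeholder). It preprocesses one image with the checkpoint's
# own ViTImageProcessor and runs a forward pass to get both heads' logits.
if __name__ == "__main__":
    from PIL import Image
    from transformers import ViTImageProcessor

    processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    model = BirdCaptioningModel(num_classes=200)
    model.eval()

    image = Image.open("bird.jpg").convert("RGB")  # hypothetical input image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        caption_logits, class_logits = model(pixel_values)

    # class_logits has shape (batch_size, num_classes); argmax gives the
    # predicted bird class id.
    print("Predicted class id:", class_logits.argmax(dim=-1).item())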