File size: 1,329 Bytes
10e4167 bd734a0 10e4167 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
from transformers import VisionEncoderDecoderModel
class BirdCaptioningModel(nn.Module, PyTorchModelHubMixin):
    """Joint bird captioning + classification model.

    Wraps a pretrained ViT-GPT2 ``VisionEncoderDecoderModel`` and adds a
    linear classification head over the decoder's final hidden state, so a
    single forward pass yields both caption token logits and class logits.

    The ``PyTorchModelHubMixin`` base adds ``save_pretrained`` /
    ``from_pretrained`` / ``push_to_hub`` support; init kwargs below are
    captured in the serialized config, so they must stay JSON-friendly.
    """

    def __init__(self, num_classes=200,
                 base_model_name="nlpconnect/vit-gpt2-image-captioning"):
        """
        Args:
            num_classes: Number of bird classes for the classification head
                (default 200, i.e. the CUB-200 label set — presumably; the
                dataset is not visible from this file).
            base_model_name: Hub id of the pretrained vision-encoder-decoder
                checkpoint to load. Kept as a keyword with the original
                hard-coded value as default for backward compatibility.
        """
        super().__init__()
        # NOTE: downloads weights from the Hugging Face Hub on first use.
        self.base_model = VisionEncoderDecoderModel.from_pretrained(base_model_name)
        self.hidden_size = self.base_model.decoder.config.hidden_size
        self.classifier = nn.Linear(self.hidden_size, num_classes)

    def forward(self, pixel_values, input_ids=None, attention_mask=None):
        """Run captioning and classification in one pass.

        Args:
            pixel_values: Preprocessed image batch for the ViT encoder.
            input_ids: Decoder input token ids. If ``None`` (inference /
                classification-only use), a single decoder-start token per
                sample is synthesized.
            attention_mask: Decoder attention mask matching ``input_ids``.

        Returns:
            Tuple ``(caption_logits, class_logits)`` — the language-model
            logits from the decoder and the per-class logits from the
            linear head.
        """
        if input_ids is None:
            batch_size = pixel_values.shape[0]
            # BUGFIX: pass dtype explicitly — older torch versions did not
            # infer an integer dtype from the int fill value, producing
            # float ids that the decoder embedding cannot index with.
            input_ids = torch.full(
                (batch_size, 1),
                self.base_model.config.decoder_start_token_id,
                dtype=torch.long,
                device=pixel_values.device,
            )
            attention_mask = torch.ones_like(input_ids)
        outputs = self.base_model(
            pixel_values=pixel_values,
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
            output_hidden_states=True,  # needed to reach the decoder's last layer
            return_dict=True,
        )
        # Classify from the last decoder layer's hidden state at the first
        # position (it attends to the image via cross-attention).
        hidden_states = outputs.decoder_hidden_states[-1][:, 0, :]
        class_logits = self.classifier(hidden_states)
        return outputs.logits, class_logits
|