# (Removed scrape residue: code-viewer page header — file size, commit hashes
#  10e4167 / bd734a0, and a line-number gutter 1–34 — not part of the source.)

from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
from transformers import VisionEncoderDecoderModel

class BirdCaptioningModel(nn.Module, PyTorchModelHubMixin):
    """Joint captioning + classification model for bird images.

    Wraps the pretrained ViT-GPT2 image-captioning backbone and adds a
    linear head that classifies the image into ``num_classes`` species
    from the decoder's hidden state.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        # Pretrained vision-encoder/decoder backbone (ViT encoder, GPT-2 decoder).
        self.base_model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        self.hidden_size = self.base_model.decoder.config.hidden_size
        # Species-classification head over the decoder hidden state.
        self.classifier = nn.Linear(self.hidden_size, num_classes)

    def forward(self, pixel_values, input_ids=None, attention_mask=None):
        """Run the backbone and the classification head.

        Returns a tuple ``(caption_logits, class_logits)``: the language-model
        logits from the backbone and the species logits from the linear head.
        """
        # No decoder tokens supplied: seed decoding with the start token
        # (one token per image in the batch, on the same device).
        if input_ids is None:
            n = pixel_values.shape[0]
            start_id = self.base_model.config.decoder_start_token_id
            input_ids = torch.full((n, 1), start_id, device=pixel_values.device)
            attention_mask = torch.ones_like(input_ids)

        outputs = self.base_model(
            pixel_values=pixel_values,
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        # Pool the last decoder layer's hidden state at the first token
        # position and feed it to the classifier.
        pooled = outputs.decoder_hidden_states[-1][:, 0, :]
        return outputs.logits, self.classifier(pooled)