INVERTO committed on
Commit
bd734a0
·
verified ·
1 Parent(s): 490a0e4

Upload trained bird captioning model, tokenizer, image processor, species mapping, and captions

Browse files
Files changed (1) hide show
  1. model.py +9 -0
model.py CHANGED
@@ -12,6 +12,15 @@ class BirdCaptioningModel(nn.Module, PyTorchModelHubMixin):
12
  self.classifier = nn.Linear(self.hidden_size, num_classes)
13
 
14
  def forward(self, pixel_values, input_ids=None, attention_mask=None):
 
 
 
 
 
 
 
 
 
15
  outputs = self.base_model(
16
  pixel_values=pixel_values,
17
  decoder_input_ids=input_ids,
 
12
  self.classifier = nn.Linear(self.hidden_size, num_classes)
13
 
14
  def forward(self, pixel_values, input_ids=None, attention_mask=None):
15
+ if input_ids is None:
16
+ batch_size = pixel_values.shape[0]
17
+ input_ids = torch.full(
18
+ (batch_size, 1),
19
+ self.base_model.config.decoder_start_token_id,
20
+ device=pixel_values.device
21
+ )
22
+ attention_mask = torch.ones_like(input_ids)
23
+
24
  outputs = self.base_model(
25
  pixel_values=pixel_values,
26
  decoder_input_ids=input_ids,