Error while fine-tuning BLIP on own task: found two devices

#3
by rubenjanss - opened

Hi,

I'm fine-tuning this BLIP checkpoint on my own dataset, but I'm running into an error while training. Either I'm doing something wrong in loading/processing the data or training the model, or there is a bug with the BlipModel?

I'm loading the data and training the model as follows:

# Load the BLIP captioning checkpoint together with its preprocessing pipeline.
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
# The processor already bundles the text tokenizer; reuse it instead of
# downloading and constructing a second copy via AutoTokenizer.
tokenizer = processor.tokenizer

class VCSDatasetProcessor(Dataset):
    """Image-captioning dataset pairing processed images with tokenized text.

    Each item is a dict with:
      - "pixel_values": float tensor (3, H, W) from the BLIP processor
        (resized + normalized, batch dim removed)
      - "labels": int tensor of token ids, with PAD positions set to -100
        so the loss function ignores them
    """

    def __init__(self, root_dir, df, processor, tokenizer, max_target_length=128):
        """
        Args:
            root_dir: folder containing the image files (including trailing slash).
            df: dataframe with "image_id" and "question" columns.
            processor: image processor used for resize + normalize.
            tokenizer: text tokenizer used to build the labels.
            max_target_length: texts are padded/truncated to this many tokens.
        """
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Use positional indexing (.iloc): label-based df["col"][idx] raises
        # KeyError when the dataframe index is non-default (e.g. after a
        # train/valid split without reset_index).
        row = self.df.iloc[idx]
        file_name = row["image_id"]
        text = row["question"]

        # Prepare image (i.e. resize + normalize).
        image = Image.open(self.root_dir + file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # Encode the text; PAD token ids become -100 so the loss ignores them.
        labels = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids
        labels = [tok if tok != self.tokenizer.pad_token_id else -100 for tok in labels]

        # squeeze(0) drops only the batch dim added by return_tensors="pt",
        # never a genuine size-1 image dimension.
        return {
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels),
        }

def _make_split(split_df):
    # Build one dataset split over the shared image folder.
    return VCSDatasetProcessor(
        root_dir="data/images/",
        df=split_df,
        processor=processor,
        tokenizer=tokenizer,
    )

train_set = _make_split(train_df)
valid_set = _make_split(valid_df)

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Seq2seq training configuration: evaluate every 200 steps, checkpoint every
# 1000, and decode with generate() during evaluation so text metrics can run.
training_args = Seq2SeqTrainingArguments(
    output_dir="./blip/",
    evaluation_strategy="steps",
    eval_steps=200,
    save_steps=1000,
    logging_steps=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    fp16=False,  # full precision; enable on supported GPUs to cut memory use
    predict_with_generate=True,
)

from transformers import default_data_collator

# Instantiate the trainer. The processor is passed as "tokenizer" so it gets
# saved with every checkpoint; default_data_collator stacks the per-item dicts
# produced by VCSDatasetProcessor into batches.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=valid_set,
    data_collator=default_data_collator,
    tokenizer=processor,
    compute_metrics=compute_metrics,
)

trainer.train()

And I get the following error trace:

***** Running training *****
  Num examples = 8886
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 6666
  Number of trainable parameters = 469732924

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [45], line 1
----> 1 trainer.train()

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1534     self.model_wrapped = self.model
   1536 inner_training_loop = find_executable_batch_size(
   1537     self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1538 )
-> 1539 return inner_training_loop(
   1540     args=args,
   1541     resume_from_checkpoint=resume_from_checkpoint,
   1542     trial=trial,
   1543     ignore_keys_for_eval=ignore_keys_for_eval,
   1544 )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1787, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1785         tr_loss_step = self.training_step(model, inputs)
   1786 else:
-> 1787     tr_loss_step = self.training_step(model, inputs)
   1789 if (
   1790     args.logging_nan_inf_filter
   1791     and not is_torch_tpu_available()
   1792     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1793 ):
   1794     # if loss is nan or inf simply add the average of previous logged losses
   1795     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2535, in Trainer.training_step(self, model, inputs)
   2532     return loss_mb.reduce_mean().detach().to(self.args.device)
   2534 with self.compute_loss_context_manager():
-> 2535     loss = self.compute_loss(model, inputs)
   2537 if self.args.n_gpu > 1:
   2538     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2567, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2565 else:
   2566     labels = None
-> 2567 outputs = model(**inputs)
   2568 # Save past state if it exists
   2569 # TODO: this needs to be fixed and made cleaner later.
   2570 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/blip/modeling_blip.py:1011, in BlipForConditionalGeneration.forward(self, pixel_values, input_ids, attention_mask, output_attentions, output_hidden_states, labels, return_dict)
   1008 if labels is None:
   1009     labels = input_ids.masked_fill(input_ids == self.decoder_pad_token_id, -100)
-> 1011 outputs = self.text_decoder(
   1012     input_ids=input_ids,
   1013     attention_mask=attention_mask,
   1014     encoder_hidden_states=image_embeds,
   1015     labels=labels,
   1016     return_dict=return_dict,
   1017 )
   1019 if not return_dict:
   1020     outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/blip/modeling_blip_text.py:875, in BlipTextLMHeadModel.forward(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, return_logits, is_decoder, reduction)
    872 if labels is not None:
    873     use_cache = False
--> 875 outputs = self.bert(
    876     input_ids,
    877     attention_mask=attention_mask,
    878     position_ids=position_ids,
    879     head_mask=head_mask,
    880     inputs_embeds=inputs_embeds,
    881     encoder_hidden_states=encoder_hidden_states,
    882     encoder_attention_mask=encoder_attention_mask,
    883     past_key_values=past_key_values,
    884     use_cache=use_cache,
    885     output_attentions=output_attentions,
    886     output_hidden_states=output_hidden_states,
    887     return_dict=return_dict,
    888     is_decoder=is_decoder,
    889 )
    891 sequence_output = outputs[0]
    892 prediction_scores = self.cls(sequence_output)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/transformers/models/blip/modeling_blip_text.py:738, in BlipTextModel.forward(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, encoder_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, is_decoder)
    734     attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)))
    736 # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    737 # ourselves in which case we just need to make it broadcastable to all heads.
--> 738 extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
    739     attention_mask, input_shape, device, is_decoder
    740 )
    742 # If a 2D or 3D attention mask is provided for the cross-attention
    743 # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    744 if encoder_hidden_states is not None:

File /opt/conda/lib/python3.10/site-packages/transformers/models/blip/modeling_blip_text.py:645, in BlipTextModel.get_extended_attention_mask(self, attention_mask, input_shape, device, is_decoder)
    634         prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
    635         causal_mask = torch.cat(
    636             [
    637                 torch.ones(
   (...)
    642             axis=-1,
    643         )
--> 645     extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    646 else:
    647     extended_attention_mask = attention_mask[:, None, None, :]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Does anyone have an idea what could cause this?

Best wishes,

Ruben

hi @rubenjanss
This should be addressed in https://github.com/huggingface/transformers/pull/21021
Meanwhile, you can install transformers from this branch with: pip install git+https://github.com/younesbelkada/transformers.git@blip-train-support
Here is also a colab notebook on how to fine tune BLIP on a custom dataset: https://colab.research.google.com/drive/1lbqiSiA0sDF7JDWPeS0tccrM85LloVha?usp=sharing

Now that https://github.com/huggingface/transformers/pull/21021 has been merged, you can install transformers from source and everything should work!

pip install git+https://github.com/huggingface/transformers.git@main

Thank you so much @ybelkada , it's working now!

Thanks! Feel free to close the issue ;)
Don't hesitate to share your model once you have fine-tuned it!

Closing as completed

ybelkada changed discussion status to closed

Sign up or log in to comment