Error while fine-tuning BLIP on own task: found two devices
Hi,
I'm fine-tuning this BLIP checkpoint on my own dataset, but I'm running into an error while training. Either I'm doing something wrong when loading/processing the data or training the model, or there is a bug in BlipModel.
I'm loading the data and training the model as follows:
# Load the pretrained BLIP captioning checkpoint plus its tokenizer and processor.
# NOTE(review): AutoProcessor for this checkpoint presumably already bundles an
# image processor and a tokenizer, which would make the separate AutoTokenizer
# redundant (but harmless) — confirm against the checkpoint's processor config.
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip-image-captioning-large")
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
class VCSDatasetProcessor(Dataset):
    """Image-captioning dataset for BLIP fine-tuning.

    Each item pairs the processor's ``pixel_values`` for an image with the
    tokenized target text as ``labels`` (PAD token ids replaced by -100 so
    they are ignored by the loss function).
    """

    def __init__(self, root_dir, df, processor, tokenizer, max_target_length=128):
        """
        Args:
            root_dir: directory containing the image files; must end with a
                path separator, since file names are appended by concatenation.
            df: DataFrame with "image_id" (file name) and "question"
                (target text) columns.
            processor: image processor that produces ``pixel_values``.
            tokenizer: tokenizer used to encode the target text.
            max_target_length: padded/truncated length of the label sequence.
        """
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # BUG FIX: use positional indexing (.iloc). DataFrames produced by a
        # train/valid split keep their original, non-contiguous index, so the
        # label-based `self.df["image_id"][idx]` would raise KeyError for most
        # values of idx.
        file_name = self.df["image_id"].iloc[idx]
        text = self.df["question"].iloc[idx]

        # Prepare image (i.e. resize + normalize).
        image = Image.open(self.root_dir + file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # Add labels (input_ids) by encoding the text.
        labels = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids
        # Important: make sure that PAD tokens are ignored by the loss function.
        labels = [label if label != self.tokenizer.pad_token_id else -100 for label in labels]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
# Build the train/validation datasets; images live under data/images/ and the
# (image_id, question) pairs come from the corresponding DataFrames.
train_set = VCSDatasetProcessor(root_dir="data/images/", df=train_df, processor=processor, tokenizer=tokenizer)
valid_set = VCSDatasetProcessor(root_dir="data/images/", df=valid_df, processor=processor, tokenizer=tokenizer)
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
# Training configuration: evaluate every 200 steps using generate(), save a
# checkpoint every 1000 steps, and log every 2 steps.
training_args = Seq2SeqTrainingArguments(
    output_dir="./blip/",            # where checkpoints and logs are written
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    fp16=False,                      # full-precision training
    predict_with_generate=True,      # run generate() during evaluation
    evaluation_strategy="steps",
    eval_steps=200,
    logging_steps=2,
    save_steps=1000,
)
from transformers import default_data_collator
# Instantiate the trainer and launch fine-tuning.
trainer = Seq2SeqTrainer(
    model=model,
    # NOTE(review): the processor (not the tokenizer) is passed as `tokenizer`,
    # presumably so it is saved alongside checkpoints — confirm this is intentional.
    tokenizer=processor,
    args=training_args,
    # compute_metrics is defined elsewhere (not shown in this snippet).
    compute_metrics=compute_metrics,
    train_dataset=train_set,
    eval_dataset=valid_set,
    data_collator=default_data_collator,
)
trainer.train()
And I get the following error trace:
***** Running training *****
Num examples = 8886
Num Epochs = 3
Instantaneous batch size per device = 4
Total train batch size (w. parallel, distributed & accumulation) = 4
Gradient Accumulation steps = 1
Total optimization steps = 6666
Number of trainable parameters = 469732924
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In [45], line 1
----> 1 trainer.train()
File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1534 self.model_wrapped = self.model
1536 inner_training_loop = find_executable_batch_size(
1537 self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
1538 )
-> 1539 return inner_training_loop(
1540 args=args,
1541 resume_from_checkpoint=resume_from_checkpoint,
1542 trial=trial,
1543 ignore_keys_for_eval=ignore_keys_for_eval,
1544 )
File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1787, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
1785 tr_loss_step = self.training_step(model, inputs)
1786 else:
-> 1787 tr_loss_step = self.training_step(model, inputs)
1789 if (
1790 args.logging_nan_inf_filter
1791 and not is_torch_tpu_available()
1792 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
1793 ):
1794 # if loss is nan or inf simply add the average of previous logged losses
1795 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2535, in Trainer.training_step(self, model, inputs)
2532 return loss_mb.reduce_mean().detach().to(self.args.device)
2534 with self.compute_loss_context_manager():
-> 2535 loss = self.compute_loss(model, inputs)
2537 if self.args.n_gpu > 1:
2538 loss = loss.mean() # mean() to average on multi-gpu parallel training
File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2567, in Trainer.compute_loss(self, model, inputs, return_outputs)
2565 else:
2566 labels = None
-> 2567 outputs = model(**inputs)
2568 # Save past state if it exists
2569 # TODO: this needs to be fixed and made cleaner later.
2570 if self.args.past_index >= 0:
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.10/site-packages/transformers/models/blip/modeling_blip.py:1011, in BlipForConditionalGeneration.forward(self, pixel_values, input_ids, attention_mask, output_attentions, output_hidden_states, labels, return_dict)
1008 if labels is None:
1009 labels = input_ids.masked_fill(input_ids == self.decoder_pad_token_id, -100)
-> 1011 outputs = self.text_decoder(
1012 input_ids=input_ids,
1013 attention_mask=attention_mask,
1014 encoder_hidden_states=image_embeds,
1015 labels=labels,
1016 return_dict=return_dict,
1017 )
1019 if not return_dict:
1020 outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.10/site-packages/transformers/models/blip/modeling_blip_text.py:875, in BlipTextLMHeadModel.forward(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, return_logits, is_decoder, reduction)
872 if labels is not None:
873 use_cache = False
--> 875 outputs = self.bert(
876 input_ids,
877 attention_mask=attention_mask,
878 position_ids=position_ids,
879 head_mask=head_mask,
880 inputs_embeds=inputs_embeds,
881 encoder_hidden_states=encoder_hidden_states,
882 encoder_attention_mask=encoder_attention_mask,
883 past_key_values=past_key_values,
884 use_cache=use_cache,
885 output_attentions=output_attentions,
886 output_hidden_states=output_hidden_states,
887 return_dict=return_dict,
888 is_decoder=is_decoder,
889 )
891 sequence_output = outputs[0]
892 prediction_scores = self.cls(sequence_output)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File /opt/conda/lib/python3.10/site-packages/transformers/models/blip/modeling_blip_text.py:738, in BlipTextModel.forward(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, encoder_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, is_decoder)
734 attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)))
736 # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
737 # ourselves in which case we just need to make it broadcastable to all heads.
--> 738 extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
739 attention_mask, input_shape, device, is_decoder
740 )
742 # If a 2D or 3D attention mask is provided for the cross-attention
743 # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
744 if encoder_hidden_states is not None:
File /opt/conda/lib/python3.10/site-packages/transformers/models/blip/modeling_blip_text.py:645, in BlipTextModel.get_extended_attention_mask(self, attention_mask, input_shape, device, is_decoder)
634 prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
635 causal_mask = torch.cat(
636 [
637 torch.ones(
(...)
642 axis=-1,
643 )
--> 645 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
646 else:
647 extended_attention_mask = attention_mask[:, None, None, :]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Does anyone have an idea what could cause this?
Best wishes,
Ruben
hi @rubenjanss
This should be addressed in https://github.com/huggingface/transformers/pull/21021
Meanwhile, you can install transformers from this branch with: pip install git+https://github.com/younesbelkada/transformers.git@blip-train-support
Here is also a colab notebook on how to fine tune BLIP on a custom dataset: https://colab.research.google.com/drive/1lbqiSiA0sDF7JDWPeS0tccrM85LloVha?usp=sharing
Now that https://github.com/huggingface/transformers/pull/21021 has been merged, you can install transformers from source and everything should work!
pip install git+https://github.com/huggingface/transformers.git@main
Thank you so much @ybelkada , it's working now!
Thanks! Feel free to close the issue ;)
Don't hesitate to share your model once you have fine-tuned it!
Closing as completed