update generate function
Browse files- modeling_florence2.py +11 -5
modeling_florence2.py
CHANGED
|
@@ -2655,7 +2655,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2655 |
return image_features, image_attention_mask
|
| 2656 |
|
| 2657 |
task_prefix_embeds = inputs_embeds
|
| 2658 |
-
|
| 2659 |
if task_prefix_attention_mask is None:
|
| 2660 |
task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
|
| 2661 |
|
|
@@ -2721,12 +2721,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2721 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 2722 |
"A green car parked in front of a yellow building."
|
| 2723 |
```"""
|
|
|
|
|
|
|
| 2724 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 2725 |
output_hidden_states = (
|
| 2726 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 2727 |
)
|
| 2728 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 2729 |
-
|
| 2730 |
image_features = None
|
| 2731 |
if inputs_embeds is None:
|
| 2732 |
# 1. Extract the input embeddings
|
|
@@ -2736,8 +2738,10 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2736 |
if pixel_values is not None:
|
| 2737 |
# (batch_size, num_image_tokens, hidden_size)
|
| 2738 |
image_features = self._encode_image(pixel_values)
|
| 2739 |
-
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds
|
| 2740 |
-
|
|
|
|
|
|
|
| 2741 |
if inputs_embeds is not None:
|
| 2742 |
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
| 2743 |
outputs = self.language_model(
|
|
@@ -2783,6 +2787,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2783 |
input_ids,
|
| 2784 |
inputs_embeds=None,
|
| 2785 |
pixel_values=None,
|
|
|
|
| 2786 |
**kwargs
|
| 2787 |
):
|
| 2788 |
|
|
@@ -2793,11 +2798,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2793 |
# 2. Merge text and images
|
| 2794 |
if pixel_values is not None:
|
| 2795 |
image_features = self._encode_image(pixel_values)
|
| 2796 |
-
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
|
| 2797 |
|
| 2798 |
return self.language_model.generate(
|
| 2799 |
input_ids=None,
|
| 2800 |
inputs_embeds=inputs_embeds,
|
|
|
|
| 2801 |
**kwargs
|
| 2802 |
)
|
| 2803 |
|
|
|
|
| 2655 |
return image_features, image_attention_mask
|
| 2656 |
|
| 2657 |
task_prefix_embeds = inputs_embeds
|
| 2658 |
+
|
| 2659 |
if task_prefix_attention_mask is None:
|
| 2660 |
task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
|
| 2661 |
|
|
|
|
| 2721 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 2722 |
"A green car parked in front of a yellow building."
|
| 2723 |
```"""
|
| 2724 |
+
print("asdasdasdasdasdasdasdasda")
|
| 2725 |
+
|
| 2726 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 2727 |
output_hidden_states = (
|
| 2728 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 2729 |
)
|
| 2730 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 2731 |
+
print("asdasdasdasdasdasdasdasda")
|
| 2732 |
image_features = None
|
| 2733 |
if inputs_embeds is None:
|
| 2734 |
# 1. Extract the input embeddings
|
|
|
|
| 2738 |
if pixel_values is not None:
|
| 2739 |
# (batch_size, num_image_tokens, hidden_size)
|
| 2740 |
image_features = self._encode_image(pixel_values)
|
| 2741 |
+
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
|
| 2742 |
+
|
| 2743 |
+
print(attention_mask)
|
| 2744 |
+
|
| 2745 |
if inputs_embeds is not None:
|
| 2746 |
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
| 2747 |
outputs = self.language_model(
|
|
|
|
| 2787 |
input_ids,
|
| 2788 |
inputs_embeds=None,
|
| 2789 |
pixel_values=None,
|
| 2790 |
+
attention_mask=None,
|
| 2791 |
**kwargs
|
| 2792 |
):
|
| 2793 |
|
|
|
|
| 2798 |
# 2. Merge text and images
|
| 2799 |
if pixel_values is not None:
|
| 2800 |
image_features = self._encode_image(pixel_values)
|
| 2801 |
+
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
|
| 2802 |
|
| 2803 |
return self.language_model.generate(
|
| 2804 |
input_ids=None,
|
| 2805 |
inputs_embeds=inputs_embeds,
|
| 2806 |
+
attention_mask=attention_mask,
|
| 2807 |
**kwargs
|
| 2808 |
)
|
| 2809 |
|