clean up #16
by pawlowskipawel - opened

modeling_florence2.py  CHANGED  (+15 -7)
```diff
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return x
 
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds
+        self, image_features, inputs_embeds, task_prefix_attention_mask=None
     ):
         batch_size, image_token_length = image_features.size()[:-1]
         device = image_features.device
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
             return image_features, image_attention_mask
 
         task_prefix_embeds = inputs_embeds
-
+
+        if task_prefix_attention_mask is None:
+            task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
-
-
+        if len(task_prefix_attention_mask.shape) == 3:
+            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
 
         # concat [image embeds, task prefix embeds]
         inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
@@ -2719,12 +2721,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "A green car parked in front of a yellow building."
         ```"""
+        print("asdasdasdasdasdasdasdasda")
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+        print("asdasdasdasdasdasdasdasda")
         image_features = None
         if inputs_embeds is None:
             # 1. Extra the input embeddings
@@ -2735,7 +2739,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
                 # (batch_size, num_image_tokens, hidden_size)
                 image_features = self._encode_image(pixel_values)
                 inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
-
+
+        print(attention_mask)
+
         if inputs_embeds is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
         outputs = self.language_model(
@@ -2781,6 +2787,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         input_ids,
         inputs_embeds=None,
         pixel_values=None,
+        attention_mask=None,
         **kwargs
     ):
 
@@ -2791,11 +2798,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         # 2. Merge text and images
         if pixel_values is not None:
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         return self.language_model.generate(
             input_ids=None,
             inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             **kwargs
         )
 
```
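For context, here is a minimal, self-contained sketch of the masking behaviour the new `task_prefix_attention_mask` argument enables. Only the all-ones default, the 3-D squeeze, and the embedding concatenation come from the hunks above; the helper name `merge_image_and_text_masks`, the shape constants, and the final concatenation of `image_attention_mask` with the text mask are illustrative assumptions (that last step is outside the changed lines).

```python
import torch


def merge_image_and_text_masks(image_features, task_prefix_embeds,
                               task_prefix_attention_mask=None):
    # Image tokens are always attended to: build an all-ones mask for them.
    batch_size, image_token_length = image_features.size()[:-1]
    device = image_features.device
    image_attention_mask = torch.ones(batch_size, image_token_length, device=device)

    # Mirror the patched default: with no explicit text mask, assume no padding.
    if task_prefix_attention_mask is None:
        task_prefix_attention_mask = torch.ones(
            batch_size, task_prefix_embeds.size(1), device=device
        )

    # Collapse an accidental 3-D mask back to (batch, seq), as in the patch.
    if len(task_prefix_attention_mask.shape) == 3:
        task_prefix_attention_mask = task_prefix_attention_mask[:, 0]

    # Concatenate along the sequence dimension: [image tokens | text tokens].
    inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
    attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
    return inputs_embeds, attention_mask


# Shape check only, no model weights needed:
img = torch.randn(2, 577, 768)   # (batch, image tokens, hidden)
txt = torch.randn(2, 12, 768)    # (batch, prompt tokens, hidden)
mask = torch.ones(2, 12)         # 0 where the prompt is padded
embeds, full_mask = merge_image_and_text_masks(img, txt, mask)
print(embeds.shape, full_mask.shape)  # torch.Size([2, 589, 768]) torch.Size([2, 589])
```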
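And a sketch of how a padded batch might be driven through the patched `generate`, which is what forwarding `attention_mask` into `_merge_input_ids_with_image_features` is for. The checkpoint name, prompts, and image path are placeholders, and passing `attention_mask=` to `generate` assumes this patch is applied to the remote modeling code.

```python
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Placeholder checkpoint; any Florence-2 checkpoint using this modeling file would do.
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

image = Image.open("car.jpg")                   # placeholder image
prompts = ["<CAPTION>", "<DETAILED_CAPTION>"]   # batched prompts of different lengths

# padding=True makes the tokenizer pad the shorter prompt and emit an attention mask.
inputs = processor(text=prompts, images=[image, image], return_tensors="pt", padding=True)

generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    attention_mask=inputs["attention_mask"],  # forwarded as task_prefix_attention_mask
    max_new_tokens=128,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```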