Fix create_causal_mask compatibility for split branches
Changed files: config.json (+2, −1) · gpt2_modified.py (+12, −8)
config.json
CHANGED
|
@@ -35,5 +35,6 @@
|
|
| 35 |
"bundled_segmentation_model_name": "bundled_backbones/segmenter_encoder",
|
| 36 |
"bundled_text_model_name": "bundled_backbones/text_decoder",
|
| 37 |
"bundled_tokenizer_name": ".",
|
| 38 |
-
"segmenter_weights_in_model_state": true
|
|
|
|
| 39 |
}
|
|
|
|
| 35 |
"bundled_segmentation_model_name": "bundled_backbones/segmenter_encoder",
|
| 36 |
"bundled_text_model_name": "bundled_backbones/text_decoder",
|
| 37 |
"bundled_tokenizer_name": ".",
|
| 38 |
+
"segmenter_weights_in_model_state": true,
|
| 39 |
+
"visual_projection_type": "mlp4"
|
| 40 |
}
|
gpt2_modified.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from typing import Optional, Union
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
|
@@ -11,6 +12,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttenti
|
|
| 11 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
| 12 |
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
|
| 13 |
|
|
|
|
|
|
|
| 14 |
|
| 15 |
class GPT2AttentionModified(GPT2Attention):
|
| 16 |
def forward(
|
|
@@ -169,14 +172,15 @@ class GPT2ModelModified(GPT2Model):
|
|
| 169 |
if attention_mask is not None and attention_mask.ndim < 4:
|
| 170 |
attention_mask = attention_mask.view(batch_size, -1)
|
| 171 |
|
| 172 |
-
|
| 173 |
-
config
|
| 174 |
-
inputs_embeds
|
| 175 |
-
attention_mask
|
| 176 |
-
cache_position
|
| 177 |
-
past_key_values
|
| 178 |
-
position_ids
|
| 179 |
-
|
|
|
|
| 180 |
|
| 181 |
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
| 182 |
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
|
|
|
| 1 |
from typing import Optional, Union
|
| 2 |
+
import inspect
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.nn.functional as F
|
|
|
|
| 12 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
| 13 |
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
|
| 14 |
|
| 15 |
+
_CREATE_CAUSAL_MASK_EMBEDS_ARG = "inputs_embeds" if "inputs_embeds" in inspect.signature(create_causal_mask).parameters else "input_embeds"
|
| 16 |
+
|
| 17 |
|
| 18 |
class GPT2AttentionModified(GPT2Attention):
|
| 19 |
def forward(
|
|
|
|
| 172 |
if attention_mask is not None and attention_mask.ndim < 4:
|
| 173 |
attention_mask = attention_mask.view(batch_size, -1)
|
| 174 |
|
| 175 |
+
causal_mask_kwargs = {
|
| 176 |
+
"config": self.config_causal,
|
| 177 |
+
_CREATE_CAUSAL_MASK_EMBEDS_ARG: inputs_embeds,
|
| 178 |
+
"attention_mask": attention_mask,
|
| 179 |
+
"cache_position": cache_position,
|
| 180 |
+
"past_key_values": past_key_values,
|
| 181 |
+
"position_ids": position_ids,
|
| 182 |
+
}
|
| 183 |
+
causal_mask = create_causal_mask(**causal_mask_kwargs)
|
| 184 |
|
| 185 |
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
| 186 |
if self.config.add_cross_attention and encoder_hidden_states is not None:
|