manu02 committed
Commit 897c94b · verified · 1 Parent(s): ab97d9b

Fix create_causal_mask compatibility for split branches

Files changed (2):
  1. config.json +2 -1
  2. gpt2_modified.py +12 -8
config.json CHANGED
@@ -35,5 +35,6 @@
   "bundled_segmentation_model_name": "bundled_backbones/segmenter_encoder",
   "bundled_text_model_name": "bundled_backbones/text_decoder",
   "bundled_tokenizer_name": ".",
-  "segmenter_weights_in_model_state": true
+  "segmenter_weights_in_model_state": true,
+  "visual_projection_type": "mlp4"
 }
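
The config change adds a "visual_projection_type" field to the checkpoint. A minimal sketch of how a loader could read it defensively, assuming checkpoints saved before this commit lack the field (the "linear" fallback is a hypothetical default, not taken from this repo):

# Minimal sketch, not the repo's loader: read the new field with a fallback
# so configs saved before this commit still load.
from transformers import PretrainedConfig

config = PretrainedConfig.from_json_file("config.json")

# "visual_projection_type" comes from the diff above; the "linear" default
# is a hypothetical fallback for older checkpoints.
projection_type = getattr(config, "visual_projection_type", "linear")
print(projection_type)  # "mlp4" with this commit's config.json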
gpt2_modified.py CHANGED
@@ -1,4 +1,5 @@
 from typing import Optional, Union
+import inspect
 
 import torch
 import torch.nn.functional as F
@@ -11,6 +12,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttenti
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
 
+_CREATE_CAUSAL_MASK_EMBEDS_ARG = "inputs_embeds" if "inputs_embeds" in inspect.signature(create_causal_mask).parameters else "input_embeds"
+
 
 class GPT2AttentionModified(GPT2Attention):
     def forward(
@@ -169,14 +172,15 @@ class GPT2ModelModified(GPT2Model):
         if attention_mask is not None and attention_mask.ndim < 4:
             attention_mask = attention_mask.view(batch_size, -1)
 
-        causal_mask = create_causal_mask(
-            config=self.config_causal,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            past_key_values=past_key_values,
-            position_ids=position_ids,
-        )
+        causal_mask_kwargs = {
+            "config": self.config_causal,
+            _CREATE_CAUSAL_MASK_EMBEDS_ARG: inputs_embeds,
+            "attention_mask": attention_mask,
+            "cache_position": cache_position,
+            "past_key_values": past_key_values,
+            "position_ids": position_ids,
+        }
+        causal_mask = create_causal_mask(**causal_mask_kwargs)
 
         _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
         if self.config.add_cross_attention and encoder_hidden_states is not None:
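
The heart of the fix is the module-level keyword probe: the diff guards against the embeddings argument of create_causal_mask being spelled either "inputs_embeds" or "input_embeds" depending on the installed transformers version, checking the signature once at import time and then building the call's kwargs with whichever name is present. A self-contained sketch of the same pattern, using a stand-in function since the real signature varies by release:

import inspect

# Stand-in for transformers' create_causal_mask: here the embeddings
# argument uses the older "input_embeds" spelling.
def create_causal_mask_stub(config, input_embeds, attention_mask=None):
    return None

def embeds_kwarg_name(fn):
    # Probe the signature once (the commit does this at import time)
    # instead of wrapping every call in try/except TypeError.
    params = inspect.signature(fn).parameters
    return "inputs_embeds" if "inputs_embeds" in params else "input_embeds"

# With the stub above, the probe falls back to the older spelling.
assert embeds_kwarg_name(create_causal_mask_stub) == "input_embeds"

Probing once and caching the result keeps the per-forward-pass cost at a dict construction, rather than re-inspecting the signature or catching a TypeError on every call.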