Chengxu Zhuang committed on
Commit ·
272b3c6
1
Parent(s): ee546e1
minor fix for causal mask
Browse files- modeling_flamingo.py +14 -3
modeling_flamingo.py
CHANGED
|
@@ -14,6 +14,12 @@ import transformers.models.opt.modeling_opt as modeling_opt
|
|
| 14 |
from transformers.models.opt.modeling_opt\
|
| 15 |
import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
|
| 16 |
from transformers import ViTModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
|
| 18 |
from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
|
| 19 |
from .configuration_flamingo import FlamingoConfig
|
|
@@ -232,9 +238,14 @@ class OPTDecoder(modeling_opt.OPTDecoder):
|
|
| 232 |
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
|
| 233 |
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
|
| 234 |
|
| 235 |
-
|
| 236 |
-
attention_mask
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
if self.project_in is not None:
|
| 240 |
inputs_embeds = self.project_in(inputs_embeds)
|
|
|
|
| 14 |
from transformers.models.opt.modeling_opt\
|
| 15 |
import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
|
| 16 |
from transformers import ViTModel
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from transformers.models.opt.modeling_opt import _prepare_4d_causal_attention_mask
|
| 20 |
+
except:
|
| 21 |
+
_prepare_4d_causal_attention_mask = None
|
| 22 |
+
|
| 23 |
from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
|
| 24 |
from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
|
| 25 |
from .configuration_flamingo import FlamingoConfig
|
|
|
|
| 238 |
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
|
| 239 |
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
|
| 240 |
|
| 241 |
+
if _prepare_4d_causal_attention_mask is None:
|
| 242 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
| 243 |
+
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 244 |
+
)
|
| 245 |
+
else:
|
| 246 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
| 247 |
+
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 248 |
+
)
|
| 249 |
|
| 250 |
if self.project_in is not None:
|
| 251 |
inputs_embeds = self.project_in(inputs_embeds)
|