Update modelling_single.py
Browse filesFixed issues with attn_implementation and decoder_inputs['past_key_values'].
- modelling_single.py +2 -1
modelling_single.py
CHANGED
|
@@ -114,6 +114,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
|
|
| 114 |
encoder = CvtWithProjectionHead(config=config.encoder)
|
| 115 |
|
| 116 |
# Decoder:
|
|
|
|
| 117 |
if decoder is None:
|
| 118 |
decoder = transformers.BertLMHeadModel(config=config.decoder)
|
| 119 |
|
|
@@ -242,7 +243,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
|
|
| 242 |
'decoder_input_ids': decoder_inputs['input_ids'],
|
| 243 |
'decoder_token_type_ids': token_type_ids,
|
| 244 |
'encoder_outputs': encoder_outputs,
|
| 245 |
-
'past_key_values':
|
| 246 |
'use_cache': use_cache,
|
| 247 |
}
|
| 248 |
return input_dict
|
|
|
|
| 114 |
encoder = CvtWithProjectionHead(config=config.encoder)
|
| 115 |
|
| 116 |
# Decoder:
|
| 117 |
+
config.decoder._attn_implementation = 'eager'
|
| 118 |
if decoder is None:
|
| 119 |
decoder = transformers.BertLMHeadModel(config=config.decoder)
|
| 120 |
|
|
|
|
| 243 |
'decoder_input_ids': decoder_inputs['input_ids'],
|
| 244 |
'decoder_token_type_ids': token_type_ids,
|
| 245 |
'encoder_outputs': encoder_outputs,
|
| 246 |
+
'past_key_values': past_key_values,
|
| 247 |
'use_cache': use_cache,
|
| 248 |
}
|
| 249 |
return input_dict
|