anicolson commited on
Commit
1b7f0ab
·
verified ·
1 Parent(s): 1d58753

Update modelling_single.py

Browse files

Fixed issues with attn_implementation and decoder_inputs['past_key_values'].

Files changed (1) hide show
  1. modelling_single.py +2 -1
modelling_single.py CHANGED
@@ -114,6 +114,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
114
  encoder = CvtWithProjectionHead(config=config.encoder)
115
 
116
  # Decoder:
 
117
  if decoder is None:
118
  decoder = transformers.BertLMHeadModel(config=config.decoder)
119
 
@@ -242,7 +243,7 @@ class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
242
  'decoder_input_ids': decoder_inputs['input_ids'],
243
  'decoder_token_type_ids': token_type_ids,
244
  'encoder_outputs': encoder_outputs,
245
- 'past_key_values': decoder_inputs['past_key_values'],
246
  'use_cache': use_cache,
247
  }
248
  return input_dict
 
114
  encoder = CvtWithProjectionHead(config=config.encoder)
115
 
116
  # Decoder:
117
+ config.decoder._attn_implementation = 'eager'
118
  if decoder is None:
119
  decoder = transformers.BertLMHeadModel(config=config.decoder)
120
 
 
243
  'decoder_input_ids': decoder_inputs['input_ids'],
244
  'decoder_token_type_ids': token_type_ids,
245
  'encoder_outputs': encoder_outputs,
246
+ 'past_key_values': past_key_values,
247
  'use_cache': use_cache,
248
  }
249
  return input_dict