PeteBleackley committed on
Commit 69cf4c5 · 1 Parent(s): 9052370

Attention mask in decoder

Files changed (1)
  1. qarac/models/QaracDecoderModel.py +4 -4
qarac/models/QaracDecoderModel.py CHANGED
@@ -117,11 +117,11 @@ class QaracDecoderModel(transformers.RobertaModel,
 
         """
         (v,s) = (kwargs['vector'],inputs) if 'vector' in kwargs else inputs
-
+        (seed,attention_mask) = (s['input_ids'],s['attention_mask']) if 'attention_mask' in s else (s,None)
         return self.decoder_head(torch.unsqueeze(v,1),
-                                 self.decoder_base(s,
-                                                   use_cache='vector' in kwargs).last_hidden_state,
-                                 training = kwargs.get('training',False))
+                                 self.decoder_base(seed,
+                                                   attention_mask=attention_mask,
+                                                   use_cache='vector' in kwargs).last_hidden_state)
 
     def prepare_inputs_for_generation(self,
                                       input_ids,
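For context: before this patch the decoder never forwarded an attention mask, so decoder_base attended to padding positions; the new lines unpack the tokenizer-style inputs and thread the mask through. Below is a minimal sketch (not part of the commit) of the unpacking that new lines 120-124 perform, assuming the decoder's inputs come from a Hugging Face tokenizer; the 'roberta-base' checkpoint name is only an illustration, not necessarily the one QARAC uses.

# Illustrative sketch, not from the commit: reproduces the unpacking
# performed on the decoder inputs by the patched lines 120-124.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('roberta-base')

# A tokenizer call returns a mapping with both 'input_ids' and
# 'attention_mask', matching the first branch of the new conditional.
s = tokenizer(['An example seed sequence'], return_tensors='pt', padding=True)

(seed, attention_mask) = ((s['input_ids'], s['attention_mask'])
                          if 'attention_mask' in s else (s, None))
assert attention_mask is not None  # RoBERTa tokenizers return the mask by default

Note that the patch also drops the training=kwargs.get('training',False) keyword from the decoder_head call while adding the mask.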