Spaces:
Build error
Build error
PeteBleackley
commited on
Commit
·
69cf4c5
1
Parent(s):
9052370
Attention mask in decoder
Browse files
qarac/models/QaracDecoderModel.py
CHANGED
|
@@ -117,11 +117,11 @@ class QaracDecoderModel(transformers.RobertaModel,
|
|
| 117 |
|
| 118 |
"""
|
| 119 |
(v,s) = (kwargs['vector'],inputs) if 'vector' in kwargs else inputs
|
| 120 |
-
|
| 121 |
return self.decoder_head(torch.unsqueeze(v,1),
|
| 122 |
-
self.decoder_base(
|
| 123 |
-
|
| 124 |
-
|
| 125 |
|
| 126 |
def prepare_inputs_for_generation(self,
|
| 127 |
input_ids,
|
|
|
|
| 117 |
|
| 118 |
"""
|
| 119 |
(v,s) = (kwargs['vector'],inputs) if 'vector' in kwargs else inputs
|
| 120 |
+
(seed,attention_mask) = (s['input_ids'],s['attention_mask']) if 'attention_mask' in s else (s,None)
|
| 121 |
return self.decoder_head(torch.unsqueeze(v,1),
|
| 122 |
+
self.decoder_base(seed,
|
| 123 |
+
attention_mask=attention_mask,
|
| 124 |
+
use_cache='vector' in kwargs).last_hidden_state)
|
| 125 |
|
| 126 |
def prepare_inputs_for_generation(self,
|
| 127 |
input_ids,
|