Update to new HF version
Browse files- modeling_nort5.py +18 -1
modeling_nort5.py
CHANGED
|
@@ -387,7 +387,24 @@ class NorT5Model(NorT5PreTrainedModel):
|
|
| 387 |
self.embedding.word_embedding = value
|
| 388 |
|
| 389 |
def get_encoder(self):
|
| 390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
def get_decoder(self):
|
| 393 |
return self.get_decoder_output
|
|
|
|
| 387 |
self.embedding.word_embedding = value
|
| 388 |
|
| 389 |
def get_encoder(self):
|
| 390 |
+
class EncoderWrapper:
|
| 391 |
+
def __call__(cls, *args, **kwargs):
|
| 392 |
+
return cls.forward(*args, **kwargs)
|
| 393 |
+
|
| 394 |
+
def forward(
|
| 395 |
+
cls,
|
| 396 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 397 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 398 |
+
output_hidden_states: Optional[bool] = None,
|
| 399 |
+
output_attentions: Optional[bool] = None,
|
| 400 |
+
return_dict: Optional[bool] = None,
|
| 401 |
+
):
|
| 402 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 403 |
+
|
| 404 |
+
return self.get_encoder_output(
|
| 405 |
+
input_ids, attention_mask, output_hidden_states, output_attentions, return_dict=return_dict
|
| 406 |
+
)
|
| 407 |
+
return EncoderWrapper()
|
| 408 |
|
| 409 |
def get_decoder(self):
|
| 410 |
return self.get_decoder_output
|