Fix tokenization: accept torch.Tensor input in decode()
Browse files- tokenization_gptpangu.py +4 -0
tokenization_gptpangu.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 2 |
|
|
|
|
| 3 |
import sentencepiece
|
| 4 |
import jieba
|
| 5 |
|
|
@@ -37,6 +38,9 @@ class GPTPanguTokenizer(PreTrainedTokenizer):
|
|
| 37 |
return self.decode(ids)
|
| 38 |
|
| 39 |
def decode(self, tokens, **kwargs):
    """Decode token ids back into a text string.

    Args:
        tokens: a sequence of token ids, or a ``torch.Tensor`` of ids
            (e.g. straight out of ``model.generate``).
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        The decoded string, with the tokenizer's placeholder characters
        restored: spaces inserted by sentencepiece are stripped,
        '\u2582' is mapped back to a space and '\u2583' back to a newline.
    """
    # Local import keeps this fix self-contained; sentencepiece cannot
    # consume a torch.Tensor directly, so convert it to a plain list first.
    import torch
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()
    text = self.sp.decode(tokens)
    # Undo the encoding scheme used at tokenization time:
    # sentencepiece's inserted spaces are artifacts, '\u2582' encodes a
    # real space and '\u2583' encodes a newline.
    text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
    return text
|
|
|
|
| 1 |
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 2 |
|
| 3 |
+
import torch
|
| 4 |
import sentencepiece
|
| 5 |
import jieba
|
| 6 |
|
|
|
|
| 38 |
return self.decode(ids)
|
| 39 |
|
| 40 |
def decode(self, tokens, **kwargs):
    """Turn token ids (list or ``torch.Tensor``) back into text.

    Args:
        tokens: token ids to decode; a ``torch.Tensor`` is converted to a
            plain Python list before being handed to sentencepiece.
        **kwargs: present for interface compatibility; not used.

    Returns:
        The decoded string with placeholder characters restored:
        sentencepiece-inserted spaces removed, '\u2582' -> space,
        '\u2583' -> newline.
    """
    # sentencepiece cannot consume tensors directly.
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()
    decoded = self.sp.decode(tokens)
    # Reverse the whitespace-encoding scheme applied at tokenization time.
    return decoded.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
|