Update tokenization_orion.py
Browse files- tokenization_orion.py +14 -0
tokenization_orion.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import os
|
| 4 |
from shutil import copyfile
|
| 5 |
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
| 6 |
|
| 7 |
import sentencepiece as spm
|
| 8 |
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
|
@@ -70,6 +71,7 @@ class OrionTokenizer(PreTrainedTokenizer):
|
|
| 70 |
self.add_eos_token = add_eos_token
|
| 71 |
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 72 |
self.sp_model.Load(vocab_file)
|
|
|
|
| 73 |
super().__init__(
|
| 74 |
bos_token=bos_token,
|
| 75 |
eos_token=eos_token,
|
|
@@ -118,6 +120,8 @@ class OrionTokenizer(PreTrainedTokenizer):
|
|
| 118 |
|
| 119 |
def convert_tokens_to_string(self, tokens):
|
| 120 |
"""Converts a sequence of tokens (string) in a single string."""
|
|
|
|
|
|
|
| 121 |
current_sub_tokens = []
|
| 122 |
out_string = ""
|
| 123 |
prev_is_special = False
|
|
@@ -129,12 +133,22 @@ class OrionTokenizer(PreTrainedTokenizer):
|
|
| 129 |
out_string += self.sp_model.decode(current_sub_tokens) + token
|
| 130 |
prev_is_special = True
|
| 131 |
current_sub_tokens = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
else:
|
| 133 |
current_sub_tokens.append(token)
|
| 134 |
prev_is_special = False
|
| 135 |
out_string += self.sp_model.decode(current_sub_tokens)
|
| 136 |
return out_string
|
| 137 |
|
|
|
|
|
|
|
|
|
|
| 138 |
def save_vocabulary(
|
| 139 |
self, save_directory, filename_prefix: Optional[str] = None
|
| 140 |
) -> Tuple[str]:
|
|
|
|
| 3 |
import os
|
| 4 |
from shutil import copyfile
|
| 5 |
from typing import Any, Dict, List, Optional, Tuple
|
| 6 |
+
import re
|
| 7 |
|
| 8 |
import sentencepiece as spm
|
| 9 |
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
|
|
|
| 71 |
self.add_eos_token = add_eos_token
|
| 72 |
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 73 |
self.sp_model.Load(vocab_file)
|
| 74 |
+
|
| 75 |
super().__init__(
|
| 76 |
bos_token=bos_token,
|
| 77 |
eos_token=eos_token,
|
|
|
|
| 120 |
|
| 121 |
def convert_tokens_to_string(self, tokens):
|
| 122 |
"""Converts a sequence of tokens (string) in a single string."""
|
| 123 |
+
zhPattern = re.compile(u'[\u4e00-\u9fa5]+')
|
| 124 |
+
need_convert_punctuation=(",",";","!","?",":","(",")")
|
| 125 |
current_sub_tokens = []
|
| 126 |
out_string = ""
|
| 127 |
prev_is_special = False
|
|
|
|
| 133 |
out_string += self.sp_model.decode(current_sub_tokens) + token
|
| 134 |
prev_is_special = True
|
| 135 |
current_sub_tokens = []
|
| 136 |
+
if any([True if punctuation in token else False for punctuation in need_convert_punctuation]):
|
| 137 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 138 |
+
token=self.sp_model.decode(token)
|
| 139 |
+
if zhPattern.search(out_string[-20:]):
|
| 140 |
+
token = self.to_zh_punctuation(token)
|
| 141 |
+
out_string += token
|
| 142 |
+
current_sub_tokens = []
|
| 143 |
else:
|
| 144 |
current_sub_tokens.append(token)
|
| 145 |
prev_is_special = False
|
| 146 |
out_string += self.sp_model.decode(current_sub_tokens)
|
| 147 |
return out_string
|
| 148 |
|
| 149 |
+
def to_zh_punctuation(self, token):
    """Convert ASCII punctuation in ``token`` to fullwidth (Chinese) punctuation.

    Mapping applied (one character each): ``, ; ! ? : ( )`` ->
    ``，；！？：（）``.  Any character outside this set is left untouched.

    Args:
        token (str): decoded text piece to normalize.

    Returns:
        str: ``token`` with the seven ASCII punctuation marks replaced by
        their fullwidth counterparts.
    """
    # One str.translate pass replaces the original chain of seven
    # str.replace calls: a single C-level scan, and the mapping is
    # stated once instead of being spread over seven call sites.
    return token.translate(str.maketrans(",;!?:()", "，；！？：（）"))
|
| 151 |
+
|
| 152 |
def save_vocabulary(
|
| 153 |
self, save_directory, filename_prefix: Optional[str] = None
|
| 154 |
) -> Tuple[str]:
|