codemurt committed
Commit 2b2298a · verified · 1 Parent(s): 87a3819

Upload char_tokenizer.py

Files changed (1)
  1. char_tokenizer.py +173 -0
char_tokenizer.py ADDED
@@ -0,0 +1,173 @@
"""
Copypasted from
https://huggingface.co/IlyaGusev/ru-word-stress-transformer/blob/main/char_tokenizer.py
with Apache 2.0 license
"""

import os
from typing import Optional, Tuple, List
from collections import OrderedDict

from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, AutoTokenizer


def load_vocab(vocab_file):
    vocab = OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


class CharTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "vocab.txt"}

    def __init__(
        self,
        vocab_file=None,
        pad_token="[pad]",
        unk_token="[unk]",
        bos_token="[bos]",
        eos_token="[eos]",
        cls_token="[cls]",
        sep_token="[sep]",
        mask_token="[mask]",
        space_token="▁",
        do_lower_case=False,
        *args,
        **kwargs
    ):
        self.do_lower_case = do_lower_case
        self.space_token = space_token

        if not vocab_file or not os.path.isfile(vocab_file):
            self.vocab = OrderedDict()
            self.ids_to_tokens = OrderedDict()
        else:
            self.vocab = load_vocab(vocab_file)
            self.ids_to_tokens = OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            cls_token=cls_token,
            sep_token=sep_token,
            mask_token=mask_token,
            do_lower_case=do_lower_case,
            **kwargs
        )
        self.do_lower_case = do_lower_case
        self.space_token = space_token

        if not vocab_file or not os.path.isfile(vocab_file):
            self.vocab = OrderedDict()
            self.ids_to_tokens = OrderedDict()
        else:
            self.vocab = load_vocab(vocab_file)
            self.ids_to_tokens = OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

    def train(self, file_path):
        vocab = set()
        with open(file_path) as r:
            for line in r:
                word = line.strip()
                if self.do_lower_case:
                    word = word.lower()
                vocab |= set(word)
        vocab = list(vocab)
        vocab.sort()
        special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
        vocab = special_tokens + vocab

        for i, ch in enumerate(vocab):
            self.vocab[ch] = i
        self.ids_to_tokens = vocab

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return self.vocab

    def _convert_token_to_id(self, token):
        if self.do_lower_case:
            token = token.lower()
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens[index]

    def prepare_for_tokenization(
        self, text, is_split_into_words: bool = False, spaces=0, **kwargs
    ):
        if spaces:
            pad = self.space_token * spaces
            text = pad + pad.join(text) + pad
        return (text, kwargs)

    def _tokenize(self, text, spaces=0):
        if self.do_lower_case:
            text = text.lower()
        return list(text)

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        bos = [self.bos_token_id]
        eos = [self.eos_token_id]
        return bos + token_ids_0 + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        return (len(token_ids_0) + 2) * [0]

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        assert os.path.isdir(save_directory)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            self.vocab_files_names["vocab_file"]
        )
        index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                assert index == token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)

    def clean_up_tokenization(self, text, space='▁'):
        res = []
        prev = space
        for c in text:
            if c != prev and c != space:
                res.append(c)
            prev = c
        return ''.join(res)


AutoTokenizer.register("char_tokenizer", CharTokenizer)
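
For context, a minimal usage sketch, not part of the uploaded file: it assumes char_tokenizer.py is importable from the working directory, writes a small hypothetical vocab.txt in the format load_vocab() expects (one token per line, line number = token id), and that a reasonably recent transformers version is installed.

    # Illustrative sketch only: exercise CharTokenizer with a tiny hand-written
    # character vocabulary. The file name "vocab.txt" is hypothetical.
    from char_tokenizer import CharTokenizer

    # Special tokens first (so their ids match the order used by train()),
    # then one character per line.
    tokens = ["[pad]", "[unk]", "[bos]", "[eos]"] + list("абвгдеёжзийклмнопрстуфхцчшщъыьэюя")
    with open("vocab.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(tokens) + "\n")

    tokenizer = CharTokenizer(vocab_file="vocab.txt")

    enc = tokenizer("привет")
    print(enc["input_ids"])  # [bos] id, one id per character, [eos] id
    print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))
    print(tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(enc["input_ids"], skip_special_tokens=True)
    ))  # round-trips back to "привет"

Alternatively, train() can build the vocabulary directly from a newline-delimited word list, and save_vocabulary() writes it back out as vocab.txt in the same one-token-per-line format.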