import json
import os
import re
from collections import Counter
|
|
class TeraTokenizer:
    """TERA V2 BPE-lite tokenizer.

    The vocabulary is built in three layers:

    1. the four reserved special tokens (ids 0-3),
    2. every single character seen during training (so any trained
       character can always be encoded), and
    3. the most frequent whole words, added until ``vocab_size`` is reached.

    Words missing from the vocabulary are encoded character-by-character,
    with ``<unk>`` for characters never seen in training.
    """

    # Reserved tokens; their ids are fixed by list position (0..3).
    SPECIAL = ["<pad>", "<unk>", "<bos>", "<eos>"]

    def __init__(self):
        self.word2id = {}    # token string -> id
        self.id2word = {}    # id -> token string
        self.vocab_size = 0  # total token count, specials included
        # Fixed ids matching the order of SPECIAL.
        self.pad_id = 0
        self.unk_id = 1
        self.bos_id = 2
        self.eos_id = 3
        # HF-style aliases kept for caller compatibility.
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

    @staticmethod
    def _split(text):
        """Split *text* into alpha runs, digit runs, and single non-space symbols."""
        return re.findall(r"[A-Za-z]+|[0-9]+|[^\s]", text.strip())

    def train(self, texts, vocab_size=1500):
        """Build the vocabulary from an iterable of training strings.

        Args:
            texts: iterable of strings; lowercased internally.
            vocab_size: target total vocabulary size, specials included.
                NOTE: if the seen-character set alone exceeds this budget,
                the final vocabulary will be larger than requested — every
                character is always kept so encode() can fall back.

        Returns:
            self, allowing ``TeraTokenizer().train(...)`` chaining.
        """
        freq = Counter()
        for t in texts:
            freq.update(self._split(t.lower()))

        # Seed with every character seen so encode() can always fall back
        # to per-character ids for out-of-vocabulary words.
        tokens = sorted({c for w in freq for c in w})
        token_set = set(tokens)

        # Add whole words by descending frequency until the budget is hit.
        # most_common() preserves first-seen order on ties, matching the
        # original stable descending sort.
        for w, _ in freq.most_common():
            if len(tokens) + len(self.SPECIAL) >= vocab_size:
                break
            if w not in token_set:
                tokens.append(w)
                token_set.add(w)

        all_tokens = list(self.SPECIAL) + tokens
        self.word2id = {w: i for i, w in enumerate(all_tokens)}
        self.id2word = {i: w for w, i in self.word2id.items()}
        self.vocab_size = len(all_tokens)
        return self

    def encode(self, text, add_special=True):
        """Encode *text* into a list of token ids.

        Unknown words fall back to per-character ids; characters absent
        from the vocabulary map to ``unk_id``. With ``add_special`` the
        result is wrapped in ``bos_id`` / ``eos_id``.
        """
        ids = [self.bos_id] if add_special else []
        for w in self._split(text.lower()):
            if w in self.word2id:
                ids.append(self.word2id[w])
            else:
                # Out-of-vocabulary word: encode character by character.
                ids.extend(self.word2id.get(c, self.unk_id) for c in w)
        if add_special:
            ids.append(self.eos_id)
        return ids

    def decode(self, ids):
        """Decode ids to a space-joined string, dropping pad/bos/eos tokens."""
        skip = (self.pad_id, self.bos_id, self.eos_id)
        return " ".join(
            self.id2word.get(i, "<unk>") for i in ids if i not in skip
        )

    def tokenize(self, text):
        """Return the token strings for *text* (no special tokens added)."""
        return [self.id2word.get(i, "<unk>") for i in self.encode(text, add_special=False)]

    def size(self):
        """Return the total vocabulary size (specials included)."""
        return self.vocab_size

    def save(self, path):
        """Serialize the vocabulary to a JSON file at *path*."""
        data = {
            "word2id": self.word2id,
            "id2word": {int(k): v for k, v in self.id2word.items()},
            "vocab_size": self.vocab_size,
        }
        # Explicit encoding: never depend on the platform default codec.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f)

    def load(self, path):
        """Restore a vocabulary previously written by save().

        JSON stringifies integer keys, so id2word keys are converted back
        to int here. Returns self for chaining.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.word2id = data["word2id"]
        self.id2word = {int(k): v for k, v in data["id2word"].items()}
        self.vocab_size = data["vocab_size"]
        return self
|
|