File size: 3,074 Bytes
02136f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import re
import os

class TeraTokenizer:
    """TERA V2 word-level tokenizer with a per-character fallback.

    The vocabulary consists of the four special tokens in ``SPECIAL``, every
    non-whitespace character seen during training, and the most frequent
    whole words.  Text is lower-cased before training and encoding, so the
    mapping is case-insensitive and ``decode`` only ever yields lower-case
    tokens.

    NOTE: although historically labelled "BPE-lite", no merge rules are
    learned; unknown words are simply spelled out character by character.
    """

    SPECIAL = ["<pad>", "<unk>", "<bos>", "<eos>"]

    # Runs of ASCII letters, runs of ASCII digits, or any single
    # non-whitespace character.  Compiled once instead of on every call.
    _SPLIT_RE = re.compile(r"[A-Za-z]+|[0-9]+|[^\s]")

    def __init__(self):
        """Create an empty, untrained tokenizer."""
        self.word2id = {}
        self.id2word = {}
        self.vocab_size = 0
        # Special-token ids equal their positions in SPECIAL, because train()
        # always places SPECIAL at the front of the vocabulary.
        self.pad_id = 0
        self.unk_id = 1
        self.bos_id = 2
        self.eos_id = 3
        # Same ids exposed under the ``*_token_id`` spelling.
        self.pad_token_id = 0
        self.unk_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

    # ---- tokenize text into word pieces ----
    @staticmethod
    def _split(text):
        """Split *text* into letter runs, digit runs and single symbols.

        Whitespace is discarded; case is left untouched (callers lower-case).
        """
        return TeraTokenizer._SPLIT_RE.findall(text.strip())

    @staticmethod
    def _as_int(value):
        """Return *value* as a plain ``int`` if it implements ``__index__``.

        Integer-like scalars that hash by identity would otherwise never
        match the integer keys of ``id2word``.  Anything that is not
        integer-like is returned unchanged.
        """
        to_index = getattr(value, "__index__", None)
        if to_index is None:
            return value
        try:
            return to_index()
        except TypeError:
            # e.g. a non-integer array/tensor: leave it for the caller.
            return value

    # ---- train on list of strings ----
    def train(self, texts, vocab_size=1500):
        """Build the vocabulary from an iterable of strings.

        Args:
            texts: Iterable of training strings.
            vocab_size: Target vocabulary size, special tokens included.
                Every character seen is always kept, so the final size can
                exceed this value when the character set alone is larger
                than the budget; read ``self.vocab_size`` afterwards.

        Returns:
            ``self``, to allow chaining.
        """
        freq = {}
        for text in texts:
            for word in self._split(text.lower()):
                freq[word] = freq.get(word, 0) + 1

        # Characters come first so that any word made of seen characters
        # can always be encoded through the fallback path.
        tokens = sorted({ch for word in freq for ch in word})
        known = set(tokens)

        # Fill the remaining budget with whole words, most frequent first.
        # sorted() is stable, so ties keep first-seen order.
        budget = vocab_size - len(self.SPECIAL)
        for word, _count in sorted(freq.items(), key=lambda item: -item[1]):
            if len(tokens) >= budget:
                break
            if word not in known:
                tokens.append(word)
                known.add(word)

        # build vocab
        all_tokens = list(self.SPECIAL) + tokens
        self.word2id = {tok: idx for idx, tok in enumerate(all_tokens)}
        self.id2word = {idx: tok for tok, idx in self.word2id.items()}
        self.vocab_size = len(all_tokens)
        return self

    def encode(self, text, add_special=True):
        """Convert *text* to a list of token ids.

        Args:
            text: Input string; it is lower-cased before lookup.
            add_special: When true, wrap the result in ``bos_id``/``eos_id``.

        Returns:
            List of ids.  Words missing from the vocabulary are emitted one
            character at a time; unknown characters map to ``unk_id``.
        """
        ids = []
        if add_special:
            ids.append(self.bos_id)
        for word in self._split(text.lower()):
            word_id = self.word2id.get(word)
            if word_id is not None:
                ids.append(word_id)
            else:
                # character fallback
                ids.extend(self.word2id.get(ch, self.unk_id) for ch in word)
        if add_special:
            ids.append(self.eos_id)
        return ids

    def decode(self, ids):
        """Convert ids back to a space-separated token string.

        ``pad``, ``bos`` and ``eos`` ids are dropped; ids that are not in the
        vocabulary are rendered as ``"<unk>"``.  Original spacing and case
        are not recoverable.

        Args:
            ids: Iterable of ids.  Integer-like scalars implementing
                ``__index__`` are accepted as well as plain ints.

        Returns:
            The decoded tokens joined by single spaces.
        """
        skipped = (self.pad_id, self.bos_id, self.eos_id)
        pieces = []
        for raw in ids:
            tok_id = self._as_int(raw)
            if tok_id in skipped:
                continue
            pieces.append(self.id2word.get(tok_id, "<unk>"))
        return " ".join(pieces)

    def tokenize(self, text):
        """Return the token strings *text* encodes to (no special tokens)."""
        return [
            self.id2word.get(i, "<unk>")
            for i in self.encode(text, add_special=False)
        ]

    def size(self):
        """Return the current vocabulary size."""
        return self.vocab_size

    def save(self, path):
        """Write the vocabulary to *path* as JSON.

        The file is opened as UTF-8 explicitly so the result does not depend
        on the platform's default encoding.
        """
        data = {
            "word2id": self.word2id,
            "id2word": {int(k): v for k, v in self.id2word.items()},
            "vocab_size": self.vocab_size,
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f)

    def load(self, path):
        """Load a vocabulary previously written by :meth:`save`.

        Args:
            path: JSON file containing ``word2id``, ``id2word`` and
                ``vocab_size``.

        Returns:
            ``self``, to allow chaining.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.word2id = data["word2id"]
        # JSON object keys are always strings; restore the integer ids.
        self.id2word = {int(k): v for k, v in data["id2word"].items()}
        self.vocab_size = data["vocab_size"]
        return self