File size: 1,983 Bytes
f907cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import re


# Splits text into tokens:
#   * each newline is its own token, so line structure survives a round trip;
#   * a run of word characters and apostrophes is one word, which keeps
#     contractions such as "don't" whole;
#   * any other non-whitespace character is a single punctuation token.
# ``\w`` is used instead of ``[A-Za-z0-9_]`` so that non-ASCII letters
# (e.g. "café") are kept. Without it they match no alternative, because
# ``[^\w\s]`` excludes them, and ``findall`` silently drops them.
TOKEN_PATTERN = re.compile(r"\n|[\w']+|[^\w\s]")


class WordTokenizer:
    """Word-level tokenizer whose vocabulary is built from a training text.

    Text is split into tokens with the module-level ``TOKEN_PATTERN``. The
    vocabulary always begins with the special tokens ``<pad>``, ``<unk>``,
    ``<bos>`` and ``<eos>``. After :meth:`fit` they have ids 0-3, and every
    distinct corpus token follows in sorted order.
    """

    # Punctuation that attaches to the preceding word when decoding
    # (no space before it).
    _NO_SPACE_BEFORE = frozenset({".", ",", "!", "?", ":", ";"})

    def __init__(self):
        self.special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]
        self.stoi = {}  # token -> id
        self.itos = {}  # id -> token (inverse of stoi)

    @property
    def pad_id(self) -> int:
        """Id of the ``<pad>`` token."""
        return self.stoi["<pad>"]

    @property
    def bos_id(self) -> int:
        """Id of the ``<bos>`` (beginning-of-sequence) token."""
        return self.stoi["<bos>"]

    @property
    def eos_id(self) -> int:
        """Id of the ``<eos>`` (end-of-sequence) token."""
        return self.stoi["<eos>"]

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.stoi)

    def tokenize(self, text: str):
        """Split *text* into a list of string tokens using ``TOKEN_PATTERN``."""
        return TOKEN_PATTERN.findall(text)

    def fit(self, text: str) -> "WordTokenizer":
        """Build the vocabulary from *text*, replacing any existing one.

        Returns:
            ``self``, so calls can be chained (``WordTokenizer().fit(corpus)``).
        """
        vocab = self.special_tokens + sorted(set(self.tokenize(text)))
        self.stoi = {token: idx for idx, token in enumerate(vocab)}
        self.itos = {idx: token for token, idx in self.stoi.items()}
        return self

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False):
        """Convert *text* into a list of token ids.

        Tokens that are not in the vocabulary map to the ``<unk>`` id.

        Args:
            text: Input string.
            add_bos: Prepend the ``<bos>`` id.
            add_eos: Append the ``<eos>`` id.

        Raises:
            KeyError: If the vocabulary lacks ``<unk>`` (or ``<bos>``/``<eos>``
                when requested), e.g. before :meth:`fit` has been called.
        """
        # Look up <unk> once instead of once per token.
        unk_id = self.stoi["<unk>"]
        ids = [self.stoi.get(token, unk_id) for token in self.tokenize(text)]
        if add_bos:
            ids.insert(0, self.bos_id)
        if add_eos:
            ids.append(self.eos_id)
        return ids

    def decode(self, ids) -> str:
        """Convert token ids back into text.

        Special tokens (including ``<unk>``) and ids missing from the
        vocabulary are skipped. Tokens are separated by single spaces, except
        that newlines and the punctuation in ``_NO_SPACE_BEFORE`` attach
        directly to the preceding text. Leading and trailing whitespace is
        stripped from the result.

        Args:
            ids: Iterable of ids; each element is passed through ``int()``.
        """
        specials = set(self.special_tokens)
        parts = []
        for idx in ids:
            token = self.itos.get(int(idx), "<unk>")
            if token in specials:
                continue
            if token == "\n" or token in self._NO_SPACE_BEFORE:
                # Drop only the separator space. Stripping all whitespace
                # would also remove a preceding newline, which collapses
                # blank lines.
                if parts and parts[-1] == " ":
                    parts.pop()
            parts.append(token)
            if token != "\n":
                parts.append(" ")
        return "".join(parts).strip()

    def state_dict(self):
        """Return the serializable state: ``{"stoi": token -> id mapping}``."""
        return {"stoi": self.stoi}

    @classmethod
    def from_state_dict(cls, state) -> "WordTokenizer":
        """Rebuild a tokenizer from a mapping produced by :meth:`state_dict`."""
        tok = cls()
        tok.stoi = dict(state["stoi"])
        tok.itos = {idx: token for token, idx in tok.stoi.items()}
        return tok