File size: 829 Bytes
740c342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class CharTokenizer:
    """Character-level tokenizer mapping each distinct character to an integer id.

    Ids are assigned by sorting the distinct characters of the fitting text,
    so fitting on the same text always yields the same vocabulary.
    """

    def __init__(self):
        # Character -> id and id -> character lookups; kept as exact inverses.
        self.stoi = {}
        self.itos = {}

    @property
    def vocab_size(self) -> int:
        """Number of distinct characters in the vocabulary."""
        return len(self.stoi)

    def _set_vocab(self, stoi) -> None:
        """Install a copy of *stoi* and rebuild the inverse ``itos`` table."""
        self.stoi = dict(stoi)
        self.itos = {idx: ch for ch, idx in self.stoi.items()}

    def fit(self, text: str) -> "CharTokenizer":
        """Build the vocabulary from the distinct characters of *text*.

        Returns ``self`` so construction and fitting can be chained,
        e.g. ``CharTokenizer().fit(corpus)``.
        """
        self._set_vocab({ch: idx for idx, ch in enumerate(sorted(set(text)))})
        return self

    def encode(self, text: str) -> list:
        """Map *text* to a list of integer ids.

        Characters that are not in the vocabulary are silently skipped.
        """
        stoi = self.stoi
        return [stoi[ch] for ch in text if ch in stoi]

    def decode(self, ids) -> str:
        """Map an iterable of ids back to a string.

        Each id is passed through ``int()`` first, so any value accepted by
        ``int()`` works. Ids not in the vocabulary are silently skipped.
        """
        return "".join(self.itos.get(int(idx), "") for idx in ids)

    def state_dict(self) -> dict:
        """Return a snapshot of the vocabulary as ``{"stoi": {char: id}}``.

        The mapping is copied so that mutating the returned dict cannot
        desynchronize this tokenizer's ``stoi`` and ``itos`` tables.
        """
        return {"stoi": dict(self.stoi)}

    @classmethod
    def from_state_dict(cls, state) -> "CharTokenizer":
        """Reconstruct a tokenizer from the output of :meth:`state_dict`.

        The ``"stoi"`` mapping in *state* is copied, not aliased.
        """
        tok = cls()
        tok._set_vocab(state["stoi"])
        return tok