import re
from typing import Any, Dict, Iterable, List

# A token is one of: a newline, a run of ASCII letters/digits/underscores/
# apostrophes, or a single character that is neither a word character nor
# whitespace (i.e. punctuation).
# NOTE(review): characters that match \w but are outside [A-Za-z0-9_']
# (e.g. accented letters) match none of the alternatives and are silently
# dropped by findall -- confirm this is intended.
TOKEN_PATTERN = re.compile(r"\n|[A-Za-z0-9_']+|[^\w\s]")


class WordTokenizer:
    """Word-level tokenizer with a vocabulary learned from text.

    The vocabulary always starts with four distinct special tokens
    (padding, unknown, beginning-of-sequence, end-of-sequence), followed by
    the sorted unique tokens found in the text passed to :meth:`fit`.
    """

    # The special tokens must be distinct strings: they are used as keys of
    # ``stoi``, so identical strings would collapse into a single id.
    PAD_TOKEN = "<pad>"
    UNK_TOKEN = "<unk>"
    BOS_TOKEN = "<bos>"
    EOS_TOKEN = "<eos>"

    # Punctuation that decode() attaches directly to the preceding token.
    _CLOSING_PUNCTUATION = frozenset({".", ",", "!", "?", ":", ";"})

    def __init__(self) -> None:
        """Create an unfitted tokenizer with empty lookup tables."""
        self.special_tokens: List[str] = [
            self.PAD_TOKEN,
            self.UNK_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
        ]
        self.stoi: Dict[str, int] = {}  # token -> id
        self.itos: Dict[int, str] = {}  # id -> token

    @property
    def pad_id(self) -> int:
        """Id of the padding token (KeyError if the vocabulary lacks it)."""
        return self.stoi[self.PAD_TOKEN]

    @property
    def unk_id(self) -> int:
        """Id of the unknown token (KeyError if the vocabulary lacks it)."""
        return self.stoi[self.UNK_TOKEN]

    @property
    def bos_id(self) -> int:
        """Id of the beginning-of-sequence token (KeyError if missing)."""
        return self.stoi[self.BOS_TOKEN]

    @property
    def eos_id(self) -> int:
        """Id of the end-of-sequence token (KeyError if missing)."""
        return self.stoi[self.EOS_TOKEN]

    @property
    def vocab_size(self) -> int:
        """Number of entries in the token-to-id table."""
        return len(self.stoi)

    def tokenize(self, text: str) -> List[str]:
        """Split *text* into tokens using ``TOKEN_PATTERN``."""
        return TOKEN_PATTERN.findall(text)

    def fit(self, text: str) -> "WordTokenizer":
        """Build the vocabulary from *text* and return ``self``.

        Special tokens receive ids 0..3 in the order of
        ``self.special_tokens``; the remaining ids follow the sorted order of
        the unique tokens in *text*.
        """
        # Exclude special tokens from the corpus tokens so that every
        # vocabulary entry is unique and ids stay contiguous.
        corpus_tokens = set(self.tokenize(text)) - set(self.special_tokens)
        vocab = self.special_tokens + sorted(corpus_tokens)
        self.stoi = {token: idx for idx, token in enumerate(vocab)}
        self.itos = {idx: token for token, idx in self.stoi.items()}
        return self

    def encode(
        self, text: str, add_bos: bool = False, add_eos: bool = False
    ) -> List[int]:
        """Convert *text* to a list of token ids.

        Tokens missing from the vocabulary map to the unknown-token id.

        Args:
            text: Text to tokenize and encode.
            add_bos: Prepend the beginning-of-sequence id when true.
            add_eos: Append the end-of-sequence id when true.

        Raises:
            KeyError: If a required special token is not in the vocabulary
                (for example, when the tokenizer has not been fitted).
        """
        ids = [self.stoi.get(token, self.unk_id) for token in self.tokenize(text)]
        if add_bos:
            ids = [self.bos_id] + ids
        if add_eos:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, ids: Iterable[Any]) -> str:
        """Convert token ids back to text.

        Special tokens and ids missing from the vocabulary are skipped.
        Tokens are joined with single spaces, except that listed punctuation
        and newlines attach directly to the preceding text.

        Args:
            ids: Iterable of values convertible with ``int()``.
        """
        skipped = set(self.special_tokens)
        tokens = []
        for idx in ids:
            # Unknown ids fall back to the unknown token, which is itself a
            # special token and therefore dropped below.
            token = self.itos.get(int(idx), self.UNK_TOKEN)
            if token in skipped:
                continue
            tokens.append(token)

        text = ""
        for token in tokens:
            if token == "\n":
                # rstrip() removes the space left by the previous token (and
                # any earlier trailing newline) before the line break.
                text = text.rstrip() + "\n"
            elif token in self._CLOSING_PUNCTUATION:
                text = text.rstrip() + token + " "
            else:
                text += token + " "
        return text.strip()

    def state_dict(self) -> Dict[str, Dict[str, int]]:
        """Return the serializable state: the token-to-id table."""
        return {"stoi": self.stoi}

    @classmethod
    def from_state_dict(cls, state: Dict[str, Any]) -> "WordTokenizer":
        """Rebuild a tokenizer from a mapping produced by :meth:`state_dict`.

        The ``"stoi"`` mapping is copied, and the id-to-token table is
        derived by inverting it.
        """
        tok = cls()
        tok.stoi = dict(state["stoi"])
        tok.itos = {idx: token for token, idx in tok.stoi.items()}
        return tok