# RON-110M / code / tokenizer.py
# Uploaded by endurasolution: "Upload Ron-110M: pretrain + summarizer + tokenizer + code"
# Commit 3b97420 (verified)
from __future__ import annotations
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers import decoders as _decoders
class TextTokenizer:
    """
    Wrapper around ``tokenizers.Tokenizer`` that guarantees a decoder is attached.

    A tokenizer.json loaded via ``Tokenizer.from_file()`` may have no decoder
    block (observed with files saved from ByteLevelBPETokenizer). Without a
    decoder, ``.decode()`` returns raw byte-level symbols (e.g. ``Ġ`` for a
    leading space) and mangled multi-byte characters instead of proper UTF-8
    text. If no decoder is present after loading, a ByteLevel decoder is
    attached so that decode produces readable text.

    Attributes:
        path: Path of the loaded tokenizer JSON file.
        tokenizer: The underlying ``tokenizers.Tokenizer``.
        pad_id, bos_id, eos_id, unk_id: Ids of ``<pad>``, ``<bos>``, ``<eos>``
            and ``<unk>``. When a token is missing from the vocabulary, the id
            falls back to 0, 1, 2 and 3 respectively.
        vocab_size: Value of ``tokenizer.get_vocab_size()``.
    """

    # Special-token names recognised by this wrapper. Only the ones actually
    # present in the vocabulary are stripped by ``decode``.
    _SPECIAL_TOKENS = ("<pad>", "<bos>", "<eos>", "<unk>")

    def __init__(self, path: str | Path):
        """Load the tokenizer JSON at *path* and make sure a decoder is attached."""
        self.path = Path(path)
        self.tokenizer = Tokenizer.from_file(str(self.path))

        # Force a ByteLevel decoder if one is not attached. Reading the
        # attribute is best effort: any failure is treated as "no decoder".
        try:
            current_decoder = self.tokenizer.decoder
        except Exception:
            current_decoder = None
        if current_decoder is None:
            self.tokenizer.decoder = _decoders.ByteLevel()

        vocab = self.tokenizer.get_vocab()
        self.pad_id = vocab.get("<pad>", 0)
        self.bos_id = vocab.get("<bos>", 1)
        self.eos_id = vocab.get("<eos>", 2)
        self.unk_id = vocab.get("<unk>", 3)
        # The fallback ids above can belong to ordinary tokens when a special
        # token is absent. Only strip ids whose special token really exists,
        # otherwise decode would silently drop real text.
        self._special_ids = self._present_special_ids(vocab)
        self.vocab_size = self.tokenizer.get_vocab_size()

    @classmethod
    def _present_special_ids(cls, vocab: dict[str, int]) -> frozenset[int]:
        """Return the ids of the recognised special tokens that exist in *vocab*."""
        return frozenset(vocab[name] for name in cls._SPECIAL_TOKENS if name in vocab)

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> list[int]:
        """
        Encode *text* into token ids.

        Args:
            text: Input string.
            add_bos: If true, prepend ``self.bos_id``.
            add_eos: If true, append ``self.eos_id``.

        Returns:
            A new list of token ids.
        """
        ids = self.tokenizer.encode(text).ids
        if add_bos:
            ids = [self.bos_id] + ids
        if add_eos:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str:
        """
        Decode token ids back to text.

        Args:
            ids: Iterable of token ids. Each element is passed through ``int()``,
                so numeric scalars that support ``int()`` are accepted.
            skip_special_tokens: If true, drop the ids of the ``<pad>``,
                ``<bos>``, ``<eos>`` and ``<unk>`` tokens present in the
                vocabulary. The flag is also forwarded to the underlying
                ``Tokenizer.decode``.

        Returns:
            The decoded string.
        """
        int_ids = [int(i) for i in ids]  # convert once, not per comparison
        if skip_special_tokens:
            special = self._special_ids
            int_ids = [i for i in int_ids if i not in special]
        return self.tokenizer.decode(int_ids, skip_special_tokens=skip_special_tokens)