# Source: mesko-tts / bio_llm/utils/tokenizer.py
# Commit 424c56c (verified, by mesklintech): "Publish BioVoice-TTS sparse energy checkpoint and model card"
import json
import re
from collections import Counter
from pathlib import Path
from typing import Iterable, List, Sequence
class SimpleTokenizer:
"""A small word-and-punctuation tokenizer for CPU-only experiments."""
PAD = "<pad>"
BOS = "<bos>"
EOS = "<eos>"
UNK = "<unk>"
TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
def __init__(self, vocab: List[str]):
self.id_to_token = vocab
self.token_to_id = {token: index for index, token in enumerate(vocab)}
@classmethod
def build(cls, texts: Iterable[str], min_freq: int = 1) -> "SimpleTokenizer":
counter: Counter[str] = Counter()
for text in texts:
counter.update(cls.tokenize(text))
vocab = [cls.PAD, cls.BOS, cls.EOS, cls.UNK]
for token, freq in counter.most_common():
if freq >= min_freq and token not in vocab:
vocab.append(token)
return cls(vocab)
@staticmethod
def tokenize(text: str) -> List[str]:
return SimpleTokenizer.TOKEN_PATTERN.findall(text)
@property
def vocab_size(self) -> int:
return len(self.id_to_token)
@property
def pad_id(self) -> int:
return self.token_to_id[self.PAD]
@property
def bos_id(self) -> int:
return self.token_to_id[self.BOS]
@property
def eos_id(self) -> int:
return self.token_to_id[self.EOS]
@property
def unk_id(self) -> int:
return self.token_to_id[self.UNK]
def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
tokens = self.tokenize(text)
ids = [self.token_to_id.get(token, self.unk_id) for token in tokens]
if add_bos:
ids.insert(0, self.bos_id)
if add_eos:
ids.append(self.eos_id)
return ids
def decode(self, token_ids: Iterable[int], skip_special_tokens: bool = True) -> str:
tokens: List[str] = []
specials = {self.PAD, self.BOS, self.EOS, self.UNK}
for token_id in token_ids:
token = self.id_to_token[int(token_id)]
if skip_special_tokens and token in specials:
continue
tokens.append(token)
output = []
for token in tokens:
if output and re.match(r"\w", token) and re.match(r"\w", output[-1][-1]):
output.append(" ")
elif output and token not in {".", ",", "!", "?", ":", ";", "'", '"', ")"} and output[-1] not in {"(", '"'}:
output.append(" ")
output.append(token)
return "".join(output).strip()
def save(self, path: str | Path) -> None:
payload = {"vocab": self.id_to_token}
Path(path).write_text(json.dumps(payload, indent=2), encoding="utf-8")
@classmethod
def load(cls, path: str | Path) -> "SimpleTokenizer":
payload = json.loads(Path(path).read_text(encoding="utf-8"))
return cls(payload["vocab"])
class BPETokenizer:
"""A compact BPE tokenizer with greedy longest-match encoding."""
PAD = "<pad>"
BOS = "<bos>"
EOS = "<eos>"
UNK = "<unk>"
END_OF_WORD = "</w>"
TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
def __init__(self, vocab: Sequence[str], merges: Sequence[list[str] | tuple[str, str]]):
self.id_to_token = list(vocab)
self.token_to_id = {token: index for index, token in enumerate(self.id_to_token)}
self.merges = [tuple(pair) for pair in merges]
self.merge_ranks = {pair: index for index, pair in enumerate(self.merges)}
@classmethod
def build(
cls,
texts: Iterable[str],
vocab_size: int = 256,
min_frequency: int = 2,
) -> "BPETokenizer":
words = Counter()
for text in texts:
words.update(cls.TOKEN_PATTERN.findall(text))
word_pieces = {
word: tuple(list(word) + [cls.END_OF_WORD])
for word, frequency in words.items()
if frequency >= 1
}
merges: list[tuple[str, str]] = []
special_tokens = [cls.PAD, cls.BOS, cls.EOS, cls.UNK]
symbol_vocab = {symbol for pieces in word_pieces.values() for symbol in pieces}
while len(symbol_vocab) + len(special_tokens) < vocab_size:
pair_counts: Counter[tuple[str, str]] = Counter()
for word, pieces in word_pieces.items():
frequency = words[word]
for index in range(len(pieces) - 1):
pair_counts[(pieces[index], pieces[index + 1])] += frequency
if not pair_counts:
break
best_pair, best_frequency = pair_counts.most_common(1)[0]
if best_frequency < min_frequency:
break
merged_symbol = "".join(best_pair)
merges.append(best_pair)
updated: dict[str, tuple[str, ...]] = {}
for word, pieces in word_pieces.items():
new_pieces: list[str] = []
index = 0
while index < len(pieces):
if index < len(pieces) - 1 and (pieces[index], pieces[index + 1]) == best_pair:
new_pieces.append(merged_symbol)
index += 2
else:
new_pieces.append(pieces[index])
index += 1
updated[word] = tuple(new_pieces)
word_pieces = updated
symbol_vocab = {symbol for pieces in word_pieces.values() for symbol in pieces}
vocab = special_tokens + sorted(symbol_vocab)
return cls(vocab=vocab, merges=merges)
@staticmethod
def tokenize(text: str) -> List[str]:
return BPETokenizer.TOKEN_PATTERN.findall(text)
@property
def vocab_size(self) -> int:
return len(self.id_to_token)
@property
def pad_id(self) -> int:
return self.token_to_id[self.PAD]
@property
def bos_id(self) -> int:
return self.token_to_id[self.BOS]
@property
def eos_id(self) -> int:
return self.token_to_id[self.EOS]
@property
def unk_id(self) -> int:
return self.token_to_id[self.UNK]
def _apply_merges(self, word: str) -> list[str]:
pieces = list(word) + [self.END_OF_WORD]
if len(pieces) == 1:
return pieces
while True:
candidates = []
for index in range(len(pieces) - 1):
pair = (pieces[index], pieces[index + 1])
if pair in self.merge_ranks:
candidates.append((self.merge_ranks[pair], index, pair))
if not candidates:
break
_, merge_index, pair = min(candidates)
pieces = pieces[:merge_index] + ["".join(pair)] + pieces[merge_index + 2 :]
return pieces
def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
ids: list[int] = []
if add_bos:
ids.append(self.bos_id)
for token in self.tokenize(text):
if re.match(r"\w+", token):
pieces = self._apply_merges(token)
else:
pieces = [token + self.END_OF_WORD]
if pieces[0] not in self.token_to_id:
pieces = [token, self.END_OF_WORD]
for piece in pieces:
ids.append(self.token_to_id.get(piece, self.unk_id))
if add_eos:
ids.append(self.eos_id)
return ids
def decode(self, token_ids: Iterable[int], skip_special_tokens: bool = True) -> str:
specials = {self.PAD, self.BOS, self.EOS, self.UNK}
words: list[str] = []
current = ""
for token_id in token_ids:
token = self.id_to_token[int(token_id)]
if skip_special_tokens and token in specials:
continue
if token == self.END_OF_WORD:
if current:
words.append(current)
current = ""
continue
if token.endswith(self.END_OF_WORD):
current += token[: -len(self.END_OF_WORD)]
words.append(current)
current = ""
else:
current += token
if current:
words.append(current)
output: list[str] = []
for word in words:
if not output:
output.append(word)
elif re.match(r"^[^\w\s]+$", word):
output.append(word)
elif re.match(r"^[^\w\s]+$", output[-1]):
output.append(" ")
output.append(word)
else:
output.append(" ")
output.append(word)
return "".join(output).replace(" ", " ").strip()
def save(self, path: str | Path) -> None:
payload = {"type": "bpe", "vocab": self.id_to_token, "merges": [list(pair) for pair in self.merges]}
Path(path).write_text(json.dumps(payload, indent=2), encoding="utf-8")
@classmethod
def load(cls, path: str | Path) -> "BPETokenizer":
payload = json.loads(Path(path).read_text(encoding="utf-8"))
return cls(payload["vocab"], payload.get("merges", []))
# Type alias covering either tokenizer implementation; used as the return
# annotation of load_tokenizer(). Evaluated at import time, so `X | Y` on
# classes requires Python 3.10+.
Tokenizer = SimpleTokenizer | BPETokenizer
def load_tokenizer(path: str | Path) -> Tokenizer:
    """Read a saved tokenizer JSON file and rebuild the matching tokenizer class.

    A payload whose ``"type"`` field equals ``"bpe"`` becomes a BPETokenizer,
    with ``"merges"`` defaulting to an empty list when absent. Every other
    payload, including one with no ``"type"`` key, is treated as a
    SimpleTokenizer.
    """
    raw_text = Path(path).read_text(encoding="utf-8")
    data = json.loads(raw_text)
    if data.get("type") != "bpe":
        return SimpleTokenizer(data["vocab"])
    return BPETokenizer(data["vocab"], data.get("merges", []))