# codsworth-3.8m / codsworth / tokenizer.py
# Uploaded by Jaqshanahan - commit b84d85a (verified): "Initial upload of Codsworth model"
import os
import json
import re
from functools import cached_property
from typing import Optional, Union
import torch
from torch import Tensor
class Tokenizer:
    """Byte-level BPE tokenizer for Codsworth.

    Tokens are ``bytes`` internally. Text is split into its UTF-8 bytes, and
    adjacent tokens are concatenated according to the ordered merge rules.
    Each resulting token is then looked up in the vocabulary. ``train`` builds
    a vocabulary whose ids 0-255 are the single byte values, followed by one
    id per learned merge.

    For JSON persistence, token bytes are mapped one-to-one onto ``str`` via
    Latin-1. ``save`` writes that form and the constructor reads it back.

    NOTE(review): the default special-token ids (unk/pad 0, bos 1, eos 2)
    coincide with the ids ``train`` gives the bytes 0x00-0x02. As a result,
    ``decode`` with ``remove_special_tokens=True`` also drops those bytes.
    This is left as-is so that existing id assignments stay stable.
    """

    # UTF-8 encoding of U+FFFD REPLACEMENT CHARACTER, emitted for unknown ids.
    _REPLACEMENT_BYTES = "\ufffd".encode("utf-8")

    def __init__(
        self,
        vocab: Optional[dict[Union[str, bytes], int]] = None,
        merges: Optional[list[Union[str, tuple, list]]] = None,
        special_tokens: Optional[dict[str, int]] = None,
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<eos>",
        pad_token: str = "<pad>",
        unk_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,
    ):
        """Create a tokenizer.

        Args:
            vocab: Token -> id mapping. Keys may be ``bytes`` (as produced by
                ``train``) or ``str`` (as written by ``save``). ``str`` keys
                are converted to bytes via Latin-1. Keys containing characters
                outside Latin-1 are converted via UTF-8 instead.
            merges: Ordered merge rules, highest priority first. Each rule is
                a two-element sequence of tokens (``bytes`` or ``str``) or a
                single string ``"left right"`` with exactly one space.
            special_tokens: Special-token string -> id mapping. When omitted,
                it is built from the ``*_token`` / ``*_token_id`` arguments.
            unk_token, bos_token, eos_token, pad_token: Special-token strings.
            unk_token_id, bos_token_id, eos_token_id, pad_token_id: Ids used
                when a special token is missing from ``special_tokens``.

        Raises:
            TypeError: A vocab key or merge token is neither ``str`` nor
                ``bytes``.
            ValueError: A merge rule does not name exactly two tokens.
        """
        self.unk_token = unk_token
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token_id = unk_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self._vocab: dict[bytes, int] = {
            self._as_bytes(token): token_id for token, token_id in (vocab or {}).items()
        }
        self._merges: list[tuple[bytes, bytes]] = [
            self._as_merge(merge) for merge in (merges or [])
        ]
        if special_tokens is not None:
            self._special_tokens = special_tokens
        else:
            self._special_tokens = {
                unk_token: unk_token_id,
                bos_token: bos_token_id,
                eos_token: eos_token_id,
                pad_token: pad_token_id,
            }

    @staticmethod
    def _as_bytes(token: Union[str, bytes, bytearray]) -> bytes:
        """Convert a vocab/merge token to its internal ``bytes`` form."""
        if isinstance(token, (bytes, bytearray)):
            return bytes(token)
        if isinstance(token, str):
            try:
                # Inverse of _as_text, so save()/load() round-trip exactly.
                return token.encode("latin-1")
            except UnicodeEncodeError:
                # save() never writes characters above U+00FF, so such a key
                # presumably comes from an external file as plain text.
                return token.encode("utf-8")
        raise TypeError(f"token must be str or bytes, got {type(token).__name__}")

    @staticmethod
    def _as_text(token: bytes) -> str:
        """Map token bytes to a JSON-safe ``str`` (one character per byte)."""
        return token.decode("latin-1")

    @classmethod
    def _as_merge(cls, merge: Union[str, tuple, list]) -> tuple[bytes, bytes]:
        """Normalize one merge rule to a ``(left, right)`` pair of bytes."""
        parts = merge.split(" ") if isinstance(merge, str) else list(merge)
        if len(parts) != 2:
            raise ValueError(f"merge rule must name exactly two tokens, got {merge!r}")
        return cls._as_bytes(parts[0]), cls._as_bytes(parts[1])

    @cached_property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (special tokens not added)."""
        return len(self._vocab)

    @cached_property
    def eos_id(self) -> int:
        """Id of the end-of-sequence token."""
        return self._special_tokens.get(self.eos_token, self.eos_token_id)

    @cached_property
    def bos_id(self) -> int:
        """Id of the beginning-of-sequence token."""
        return self._special_tokens.get(self.bos_token, self.bos_token_id)

    @cached_property
    def pad_id(self) -> int:
        """Id used for padding in ``__call__``."""
        return self._special_tokens.get(self.pad_token, self.pad_token_id)

    @cached_property
    def unk_id(self) -> int:
        """Id returned for tokens missing from the vocabulary."""
        return self._special_tokens.get(self.unk_token, self.unk_token_id)

    def encode(
        self,
        text: Union[str, list[str]],
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> list[int]:
        """Encode text into a flat list of token ids.

        A list of strings is encoded as ONE concatenated sequence, with at
        most a single BOS/EOS around the whole list. Use ``batch_encode`` to
        get one sequence per string.

        Args:
            text: A string, or a list of strings to concatenate.
            add_special_tokens: Master switch for ``add_bos`` / ``add_eos``.
            add_bos: Prepend ``bos_id``.
            add_eos: Append ``eos_id``.
        """
        texts = [text] if isinstance(text, str) else text
        token_ids: list[int] = []
        for seq in texts:
            token_ids.extend(self._tokenize(seq))
        if add_special_tokens:
            if add_bos:
                token_ids.insert(0, self.bos_id)
            if add_eos:
                token_ids.append(self.eos_id)
        return token_ids

    def decode(
        self,
        token_ids: Union[list[int], Tensor],
        remove_special_tokens: bool = True,
    ) -> str:
        """Decode token ids (list or tensor) back into text.

        Args:
            token_ids: Ids to decode. A tensor is converted with ``tolist``,
                and a 0-d tensor is treated as a single id.
            remove_special_tokens: Drop every id that appears as a value in
                the special-token mapping.
        """
        if isinstance(token_ids, Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # tolist() on a 0-d tensor yields a bare int, not a list.
            token_ids = [token_ids]
        token_ids = list(token_ids)
        if remove_special_tokens:
            special_ids = set(self._special_tokens.values())
            token_ids = [t for t in token_ids if t not in special_ids]
        return self._decode_tokens(token_ids)

    def _tokenize(self, text: str) -> list[int]:
        """Encode a single string into ids, without special tokens."""
        return self._bpe_tokenize(text)

    def _bpe_tokenize(self, text: str) -> list[int]:
        """Apply the merge rules to ``text``'s UTF-8 bytes and map to ids.

        Tokens that are not in the vocabulary map to ``unk_id``.
        """
        tokens = [bytes([byte]) for byte in text.encode("utf-8")]
        ranks = self._merge_ranks
        while len(tokens) > 1:
            # Apply the highest-priority (lowest-rank) merge present.
            bigram = min(
                self._getPairs(tokens),
                key=lambda pair: ranks.get(pair, float("inf")),
            )
            if bigram not in ranks:
                break
            tokens = self._merge(tokens, bigram)
        return [self._vocab.get(token, self.unk_id) for token in tokens]

    @cached_property
    def _merge_ranks(self) -> dict[tuple[bytes, bytes], int]:
        """Map each merge rule to its priority, i.e. its index in ``_merges``.

        If a rule is duplicated, its first occurrence sets the rank.
        """
        ranks: dict[tuple[bytes, bytes], int] = {}
        for rank, pair in enumerate(self._merges):
            ranks.setdefault(pair, rank)
        return ranks

    def _getPairs(self, tokens: list) -> set:
        """Return the set of adjacent token pairs in ``tokens``."""
        return set(zip(tokens, tokens[1:]))

    def _merge(self, tokens: list, bigram: tuple) -> list:
        """Replace each occurrence of ``bigram`` with its concatenated token.

        Occurrences are matched non-overlapping, scanning left to right.
        """
        return self._replace_pair(tokens, bigram, bigram[0] + bigram[1])

    @staticmethod
    def _replace_pair(seq: list, pair: tuple, replacement) -> list:
        """Return ``seq`` with each occurrence of ``pair`` replaced.

        Occurrences are matched non-overlapping, scanning left to right, and
        each one is replaced by ``replacement``.
        """
        out = []
        i = 0
        n = len(seq)
        while i < n:
            if i < n - 1 and seq[i] == pair[0] and seq[i + 1] == pair[1]:
                out.append(replacement)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        return out

    def _decode_tokens(self, token_ids: list[int]) -> str:
        """Concatenate the bytes of ``token_ids`` and decode them as UTF-8.

        Ids missing from the vocabulary render as U+FFFD. Invalid UTF-8
        sequences are replaced the same way (``errors="replace"``).
        """
        reverse_vocab = self._reverse_vocab
        return b"".join(
            reverse_vocab.get(token_id, self._REPLACEMENT_BYTES) for token_id in token_ids
        ).decode("utf-8", errors="replace")

    @cached_property
    def _reverse_vocab(self) -> dict[int, bytes]:
        """Id -> token bytes. If several tokens share an id, the last one wins."""
        return {v: k for k, v in self._vocab.items()}

    def __call__(
        self,
        text: Union[str, list[str]],
        return_tensors: Optional[str] = None,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> dict[str, Union[list[list[int]], Tensor]]:
        """Encode one or more strings into a batch of id sequences.

        Each string becomes its own sequence.

        Args:
            text: A string or a list of strings.
            return_tensors: ``"pt"`` to return ``input_ids`` as a torch tensor.
            padding: Right-pad with ``pad_id`` to ``max_length``, or to the
                longest sequence when ``max_length`` is None.
            truncation: Cut sequences to ``max_length`` (requires it).
            max_length: Target length for padding/truncation.
            add_special_tokens, add_bos, add_eos: Forwarded to ``encode``.

        Returns:
            ``{"input_ids": ...}``.
        """
        texts = [text] if isinstance(text, str) else text
        encoded = [
            self.encode(t, add_special_tokens=add_special_tokens, add_bos=add_bos, add_eos=add_eos)
            for t in texts
        ]
        if truncation and max_length is not None:
            encoded = [e[:max_length] for e in encoded]
        if padding:
            # default=0 keeps an empty batch from raising in max().
            target = (
                max_length
                if max_length is not None
                else max((len(e) for e in encoded), default=0)
            )
            encoded = [e + [self.pad_id] * (target - len(e)) for e in encoded]
        result = {"input_ids": encoded}
        if return_tensors == "pt":
            result["input_ids"] = torch.tensor(result["input_ids"])
        return result

    def save(self, path: str) -> None:
        """Write the tokenizer to ``path`` as UTF-8 JSON.

        Parent directories are created as needed. Token bytes are stored as
        Latin-1 strings. Merges are stored as ``[left, right]`` string pairs.
        """
        directory = os.path.dirname(path)
        if directory:
            # A bare filename has no directory part; os.makedirs("") raises.
            os.makedirs(directory, exist_ok=True)
        data = {
            "vocab": {self._as_text(token): token_id for token, token_id in self._vocab.items()},
            "merges": [[self._as_text(left), self._as_text(right)] for left, right in self._merges],
            "special_tokens": self._special_tokens,
            "unk_token": self.unk_token,
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "pad_token": self.pad_token,
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str) -> "Tokenizer":
        """Read a tokenizer from a JSON file in the format written by ``save``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls(
            vocab=data.get("vocab"),
            merges=data.get("merges"),
            special_tokens=data.get("special_tokens"),
            unk_token=data.get("unk_token", "<unk>"),
            bos_token=data.get("bos_token", "<bos>"),
            eos_token=data.get("eos_token", "<eos>"),
            pad_token=data.get("pad_token", "<pad>"),
        )

    @classmethod
    def train(
        cls,
        texts: list[str],
        vocab_size: int = 50000,
        min_frequency: int = 2,
    ) -> "Tokenizer":
        """Learn a byte-level BPE vocabulary from ``texts``.

        Training starts from the 256 single-byte tokens (id == byte value).
        It then repeatedly merges the most frequent adjacent pair. Ties go to
        the pair whose smaller id is lowest. Training stops when the vocab
        reaches ``vocab_size`` or no pair occurs at least ``min_frequency``
        times. Each text is treated as one sequence, so pairs may span
        whitespace.
        """
        from collections import Counter

        vocab: dict[bytes, int] = {bytes([byte_val]): byte_val for byte_val in range(256)}
        id_to_token: dict[int, bytes] = {token_id: token for token, token_id in vocab.items()}
        merges: list[tuple[bytes, bytes]] = []
        sequences = [list(t.encode("utf-8")) for t in texts]
        while len(vocab) < vocab_size:
            counts: Counter = Counter()
            for seq in sequences:
                counts.update(zip(seq, seq[1:]))
            candidates = [pair for pair, count in counts.items() if count >= min_frequency]
            if not candidates:
                break
            best_pair = max(candidates, key=lambda pair: (counts[pair], -min(pair)))
            left, right = id_to_token[best_pair[0]], id_to_token[best_pair[1]]
            merges.append((left, right))
            new_token = left + right
            # Different merge orders can spell the same bytes. Reuse the
            # existing id rather than orphaning it.
            new_id = vocab.setdefault(new_token, len(vocab))
            id_to_token[new_id] = new_token
            sequences = [cls._replace_pair(seq, best_pair, new_id) for seq in sequences]
        return cls(
            vocab=vocab,
            merges=merges,
        )

    def batch_encode(
        self,
        texts: list[str],
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> list[list[int]]:
        """Encode each string in ``texts`` as a separate sequence."""
        return [
            self.encode(t, add_special_tokens=add_special_tokens, add_bos=add_bos, add_eos=add_eos)
            for t in texts
        ]

    def batch_decode(
        self,
        token_ids: list[list[int]],
        remove_special_tokens: bool = True,
    ) -> list[str]:
        """Decode each id sequence in ``token_ids`` into a string."""
        return [
            self.decode(ids, remove_special_tokens=remove_special_tokens)
            for ids in token_ids
        ]