File size: 6,590 Bytes
3cf4fff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
import abc
import logging
import os
from copy import copy
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import tiktoken
from sentencepiece import SentencePieceProcessor
from tiktoken.load import load_tiktoken_bpe
# Module-level logger; the SentencePiece and tiktoken constructors below report
# the loaded model and its BOS / EOS ids through it.
logger = logging.getLogger(__name__)
@dataclass
class TokenizerArgs:
    """Tokenizer settings: a tokenizer kind name and an optional model path.

    The fields mirror the ``name`` / ``path`` parameters of ``build_tokenizer``
    in this module.
    """

    # Tokenizer kind; defaults to the byte-level tokenizer ("bytes").
    name: str = "bytes"
    # Model file path; None when the chosen kind needs no model file.
    path: Optional[str] = None
class Tokenizer(abc.ABC):
    """Abstract interface implemented by every tokenizer in this module."""

    @abc.abstractmethod
    def encode(self, tokens, add_bos, add_eos):
        """Encode the input, optionally adding BOS / EOS markers."""
        pass

    @abc.abstractmethod
    def decode(self, tokens):
        """Map token ids back to text."""
        pass

    @abc.abstractmethod
    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Tuple[List[str], List[int]]:
        """Return the offsets of the tokens in the original text. Only used for evaluation."""
        pass


class MockTokenizer(Tokenizer):
    """Pass-through tokenizer used for testing: ``encode`` returns its input as-is.

    ``n_words`` advertises a vocabulary size of 256. Every abstract method of
    ``Tokenizer`` is implemented here. Without that, ``MockTokenizer()`` (and
    hence ``build_tokenizer("mock")``) raises ``TypeError``.
    """

    n_words: int = 256

    def encode(self, tokens, add_bos: bool = False, add_eos: bool = False):
        """Return ``tokens`` unchanged; ``add_bos`` / ``add_eos`` are ignored."""
        return tokens

    def decode(self, tokens):
        """Return ``tokens`` unchanged, the inverse of the pass-through ``encode``."""
        return tokens

    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Tuple[List[str], List[int]]:
        """Not supported: the mock tokenizer has no notion of text offsets.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("MockTokenizer does not track token offsets")
class ByteTokenizer(Tokenizer):
    """Tokenizer whose ids are the raw UTF-8 bytes of the text.

    Byte values use ids 0-255; id 256 is BOS and id 257 is EOS, for a
    vocabulary of 258 ids.
    """

    def __init__(self):
        self.bos_id = 256
        self.eos_id = 257
        self.n_words = 258

    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
        """Return the UTF-8 bytes of ``s`` as ints, optionally framed by BOS / EOS."""
        tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
        return tokens

    def decode(self, tokens: List[int]):
        """Decode byte ids to text.

        Ids >= 256 (BOS / EOS) are dropped. Invalid UTF-8 is rendered with
        backslash escapes instead of raising.
        """
        byte_tokens = bytes([t for t in tokens if t < 256])
        return byte_tokens.decode("utf-8", errors="backslashreplace")

    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Tuple[List[str], List[int]]:
        """Return the decoded characters and their character offsets.

        Args:
            text: Source text; encoded with ``encode`` when ``tokens`` is None.
            tokens: Byte ids to analyse. Ids >= 256 (BOS / EOS) are skipped.

        Returns:
            ``(chars, offsets)``: one entry per decoded character, where
            ``offsets[i]`` is that character's index in the decoded string.
            When ``tokens`` come from ``encode(text)`` this index is the
            index into ``text``.
        """
        if tokens is None:
            tokens = self.encode(text)
        # Decode the bytes as one UTF-8 stream, not byte-by-byte. A lone lead or
        # continuation byte decodes to nothing, so per-byte decoding would drop
        # every multi-byte character and shift all later offsets.
        decoded = bytes(t for t in tokens if t < 256).decode("utf-8", errors="ignore")
        return list(decoded), list(range(len(decoded)))
class SentencePieceTokenizer(Tokenizer):
    """Tokenizer backed by a SentencePiece model file."""

    def __init__(self, model_path: str) -> None:
        """Load the model at ``model_path`` and cache its vocabulary size and special ids.

        Raises:
            AssertionError: if ``model_path`` is not an existing file, or if the
                model's ``vocab_size()`` and ``get_piece_size()`` disagree.
        """
        assert os.path.isfile(model_path), model_path
        model = SentencePieceProcessor(model_file=model_path)
        self.sp_model = model
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # Vocabulary size plus the BOS / EOS / PAD ids, as reported by the model.
        self.n_words: int = model.vocab_size()
        self.bos_id: int = model.bos_id()
        self.eos_id: int = model.eos_id()
        self.pad_id: int = model.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert model.vocab_size() == model.get_piece_size()

    def encode(self, s: str, add_bos: bool, add_eos: bool):
        """Encode ``s`` to SentencePiece ids, optionally framed by BOS / EOS.

        Raises:
            AssertionError: if ``s`` is not exactly of type ``str``.
        """
        assert type(s) is str
        body = self.sp_model.encode(s)
        return [self.bos_id] * add_bos + body + [self.eos_id] * add_eos

    def decode(self, tokens: List[int]):
        """Decode ids back to text with the SentencePiece model."""
        decoded = self.sp_model.decode(tokens)
        return decoded

    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Tuple[List[str], List[int]]:
        """Return each piece's surface string and ``begin`` offset for ``text``.

        NOTE(review): ``tokens`` is ignored; ``text`` is always re-encoded.
        Confirm whether ``begin`` is a character or byte offset in the
        installed sentencepiece version.
        """
        proto = self.sp_model.encode_as_immutable_proto(text)
        substrs, offsets = [], []
        for piece in proto.pieces:
            substrs.append(piece.surface)
            offsets.append(piece.begin)
        return substrs, offsets
# Pre-tokenization regex passed to tiktoken as ``pat_str`` by TikTokenTokenizer.
# The pieces are concatenated into one pattern; one alternative per line.
DEFAULT_TIKTOKEN_PATTERN = (
    # English contractions, matched case-insensitively.
    r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
    # A run of letters, with at most one leading char that is not CR/LF/letter/number.
    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
    # Numbers, at most three digits per match.
    r"|\p{N}{1,3}"
    # Punctuation/symbol runs with an optional leading space and trailing CR/LF.
    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
    # Whitespace ending in one or more line breaks.
    r"|\s*[\r\n]+"
    # Whitespace not followed by a non-whitespace char.
    r"|\s+(?!\S)"
    # Any remaining whitespace.
    r"|\s+"
)

# Named special tokens mapped to slots in 0-255. TikTokenTokenizer fills the
# unused slots with reserved placeholders and shifts all of them past the BPE ranks.
DEFAULT_TIKTOKEN_SPECIAL_TOKENS = {
    # Document boundaries.
    "<|begin_of_text|>": 0,
    "<|end_of_text|>": 1,
    # Fill-in-the-middle markers.
    "<|fim_prefix|>": 2,
    "<|fim_middle|>": 3,
    "<|fim_end_fill|>": 253,
    "<|fim_pad|>": 254,
    "<|fim_suffix|>": 255,
}

# Maximum length, in characters, of each slice of the input that
# TikTokenTokenizer.encode hands to tiktoken.
TIKTOKEN_MAX_ENCODE_CHARS = 400000
class TikTokenTokenizer(Tokenizer):
    """Tokenizer backed by a tiktoken BPE ranks file.

    Special tokens occupy the 256 ids that follow the BPE ranks. Named entries
    of ``DEFAULT_TIKTOKEN_SPECIAL_TOKENS`` keep their slot, and every unused
    slot in 0-255 gets a ``<|reserved_special_token_{slot}|>`` placeholder.
    """

    def __init__(self, model_path: str) -> None:
        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)

        # Give every free slot in 0-255 a reserved placeholder name. Iterate in
        # sorted order so the dict layout does not depend on set ordering.
        slots = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS)
        for slot in sorted(set(range(256)) - set(slots.values())):
            slots[f"<|reserved_special_token_{slot}|>"] = slot
        # Shift the slots to start at len(mergeable_ranks), i.e. just after the
        # BPE ranks (assumed to be numbered 0..len-1).
        all_special_tokens_with_ids = {
            name: slot + num_base_tokens for name, slot in slots.items()
        }

        self.tkt_model = tiktoken.core.Encoding(
            name=Path(model_path).stem,
            pat_str=DEFAULT_TIKTOKEN_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=all_special_tokens_with_ids,
        )

        self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>")
        self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>")
        self.n_words: int = self.tkt_model.n_vocab
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
        """Encode ``s`` without special-token parsing, optionally framed by BOS / EOS.

        ``s`` is cut into slices of at most ``TIKTOKEN_MAX_ENCODE_CHARS``
        characters, which are encoded as one batch. Text that straddles a slice
        boundary is tokenized on each side separately.

        Raises:
            AssertionError: if ``s`` is not a ``str``.
        """
        assert isinstance(s, str)
        chunks = [
            s[i : i + TIKTOKEN_MAX_ENCODE_CHARS]
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
        ]
        # Flatten in one linear pass; sum(lists, start=[]) would re-copy the
        # growing result for every chunk (quadratic in document length).
        body = [
            token
            for chunk_tokens in self.tkt_model.encode_ordinary_batch(chunks)
            for token in chunk_tokens
        ]
        return [self.bos_id] * add_bos + body + [self.eos_id] * add_eos

    def decode(self, tokens: List[int]):
        """Decode token ids back to text."""
        return self.tkt_model.decode(tokens)

    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Tuple[List[str], List[int]]:
        """Return per-token substrings of ``text`` and their character offsets.

        When ``tokens`` is None, ``text`` is encoded here with all special
        tokens allowed. NOTE(review): unlike ``encode``, this path parses
        special-token strings and does not slice the input.

        Offsets are str indices into ``text``, assuming ``tokens`` decode to
        ``text``. Each substring runs from a token's offset to the next token's
        offset; the last one runs to the end of ``text``.
        """
        if tokens is None:
            tokens = self.tkt_model.encode(text, allowed_special="all")
        token_bytes = self.tkt_model.decode_tokens_bytes(tokens)

        text_len, offsets = 0, []
        for token in token_bytes:
            # A token beginning with a UTF-8 continuation byte (0x80-0xBF)
            # starts mid-character. Attribute it to that character's start,
            # which the previous token already counted.
            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
            # Each non-continuation byte begins a new code point, so this
            # advances text_len by the number of characters the token starts.
            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
        substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])]
        return substrs, offsets
def build_tokenizer(name: str, path: Optional[str] = None) -> Tokenizer:
    """Construct the tokenizer registered under ``name``.

    Args:
        name: One of ``"bytes"``, ``"mock"``, ``"sp"`` or ``"tiktoken"``.
        path: Model file path, forwarded to the ``"sp"`` and ``"tiktoken"``
            tokenizers and ignored by the others.

    Raises:
        NotImplementedError: if ``name`` is not one of the names above.
    """
    # Ordered (name, factory) table. The factories are lambdas so that only
    # the selected tokenizer class is looked up and constructed.
    registry = (
        ("bytes", lambda: ByteTokenizer()),
        ("mock", lambda: MockTokenizer()),
        ("sp", lambda: SentencePieceTokenizer(path)),
        ("tiktoken", lambda: TikTokenTokenizer(path)),
    )
    for registered_name, factory in registry:
        if name == registered_name:
            return factory()
    raise NotImplementedError(f"{name} tokenizer type is not implemented")
|