print("my_tokenizer.py loaded")
import base64
import logging
import os
import requests
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
import tiktoken
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizer, AddedToken
from transformers.utils import try_to_load_from_cache
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "qwen2_5.tiktoken", "ttf": "SimSun.ttf"}
# Special-token text constants
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
IMG_START = "<image>"
IMG_END = "</image>"
IMG_PAD = "<imagepad>"
REF_START = "<ref>"
REF_END = "</ref>"
BOX_START = "<box>"
BOX_END = "</box>"
QUAD_START = "<quad>"
QUAD_END = "</quad>"
class Qwen2_5_VLTokenizer(PreTrainedTokenizer):
    """Qwen2.5-VL tokenizer, modified from QWenTokenizer.

    Wraps a tiktoken BPE vocabulary (base64-encoded token/rank pairs) and
    appends the chat and vision special tokens (<|im_start|>, <|im_end|>,
    image/ref/box/quad tags) after the mergeable ranks.
    """
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        image_start_tag=IMG_START,
        image_end_tag=IMG_END,
        image_pad_tag=IMG_PAD,
        ref_start_tag=REF_START,
        ref_end_tag=REF_END,
        box_start_tag=BOX_START,
        box_end_tag=BOX_END,
        quad_start_tag=QUAD_START,
        quad_end_tag=QUAD_END,
        **kwargs,
    ):
        """Build the tiktoken encoder from ``vocab_file``.

        Args:
            vocab_file: path to the tiktoken BPE file ("<base64 token> <rank>" lines).
            errors: decode error policy passed to ``bytes.decode`` / tiktoken.
            *_tag: override strings for the vision special tokens.
            **kwargs: forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # Store the tag strings before super().__init__, which may introspect
        # tokenizer attributes.
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.image_pad_tag = image_pad_tag
        self.ref_start_tag = ref_start_tag
        self.ref_end_tag = ref_end_tag
        self.box_start_tag = box_start_tag
        self.box_end_tag = box_end_tag
        self.quad_start_tag = quad_start_tag
        self.quad_end_tag = quad_end_tag
        # Tuple of all vision-related special tags, in registration order.
        self.IMAGE_ST = (
            ref_start_tag, ref_end_tag,
            box_start_tag, box_end_tag,
            quad_start_tag, quad_end_tag,
            image_start_tag, image_end_tag,
            image_pad_tag,
        )
        super().__init__(**kwargs)
        self.errors = errors
        # Load the base BPE vocabulary: bytes token -> integer rank.
        self.mergeable_ranks = self._load_tiktoken_bpe(vocab_file)
        # Special tokens get ids immediately after the mergeable ranks,
        # chat markers first, then the vision tags.
        self.special_tokens = {
            token: index
            for index, token in enumerate(
                [IMSTART, IMEND] + list(self.IMAGE_ST),
                start=len(self.mergeable_ranks),
            )
        }
        # Reverse map (id -> token) for O(1) lookups in _convert_id_to_token.
        self.decoder = {rank: token for token, rank in self.mergeable_ranks.items()}
        self.decoder.update({index: token for token, index in self.special_tokens.items()})
        # Build the tiktoken encoder with the Qwen2.5 pre-tokenization regex.
        self.tokenizer = tiktoken.Encoding(
            "Qwen2.5",
            pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        # Frequently used special-token ids.
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]
        self.img_start_id = self.special_tokens[image_start_tag]
        self.img_end_id = self.special_tokens[image_end_tag]
        self.img_pad_id = self.special_tokens[image_pad_tag]

    def _load_tiktoken_bpe(self, tiktoken_bpe_file: str) -> Dict[bytes, int]:
        """Load a tiktoken BPE vocabulary file.

        Each non-empty line is "<base64-encoded token> <rank>".

        Returns:
            Mapping of raw token bytes to integer rank.
        """
        with open(tiktoken_bpe_file, "rb") as f:
            contents = f.read()
        return {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }

    def __len__(self) -> int:
        """Total vocabulary size including special tokens."""
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        """Return the full vocabulary (mergeable ranks plus special tokens)."""
        return {**self.mergeable_ranks, **self.special_tokens}

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Convert a single token (str special token or bytes BPE token) to its id.

        Raises:
            ValueError: if the token is in neither table.
        """
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError(f"Unknown token: {token}")

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Convert an id back to its token via the precomputed reverse map.

        Returns a str for special tokens, bytes for BPE tokens.

        Raises:
            ValueError: if the id is not in the vocabulary.
        """
        try:
            return self.decoder[index]
        except KeyError:
            raise ValueError(f"Unknown index: {index}") from None

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """Join a mixed sequence of bytes/str tokens into a string.

        Consecutive bytes tokens are buffered and decoded together so that
        multi-byte UTF-8 sequences split across tokens decode correctly.
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors=self.errors)
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should be bytes or str")
        if temp:
            text += temp.decode("utf-8", errors=self.errors)
        return text

    def tokenize(self, text: str, **kwargs) -> List[Union[bytes, str]]:
        """Tokenize ``text`` (NFC-normalized) into bytes/str tokens."""
        text = unicodedata.normalize("NFC", text)
        tokens = [self._convert_id_to_token(i) for i in self.tokenizer.encode(text)]
        return tokens

    def _decode(self, token_ids: List[int], **kwargs) -> str:
        """Decode token ids to text.

        Keyword Args:
            skip_special_tokens: drop all special-token ids.
            keep_image_special: with skip_special_tokens, still keep the
                image start/end ids so image spans survive decoding.
        """
        skip_special_tokens = kwargs.get("skip_special_tokens", False)
        keep_image_special = kwargs.get("keep_image_special", False)
        if skip_special_tokens:
            if keep_image_special:
                # Ids below len(mergeable_ranks) are regular BPE tokens.
                token_ids = [i for i in token_ids if i < len(self.mergeable_ranks) or
                             i in [self.img_start_id, self.img_end_id]]
            else:
                token_ids = [i for i in token_ids if i < len(self.mergeable_ranks)]
        return self.tokenizer.decode(token_ids, errors=self.errors)

    def to_list_format(self, text: str) -> List[Dict]:
        """Convert multimodal text to a list of {'image': ...}/{'text': ...} dicts."""
        text = unicodedata.normalize("NFC", text)
        token_ids = self.tokenizer.encode(text)

        def _encode_element(tokens):
            # An <image>...</image> span becomes an 'image' element whose
            # value is the decoded interior (e.g. the image path/url).
            if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
                return [{'image': self._decode(tokens[1:-1])}]
            # Other visual elements (ref/box/quad) not handled yet.
            return [{'text': self._decode(tokens)}]

        return self._process_visual_tokens(token_ids, _encode_element)

    def from_list_format(self, messages: List[Dict]) -> str:
        """Build multimodal text from a list of {'image': ...}/{'text': ...} dicts."""
        text = ""
        for msg in messages:
            if 'image' in msg:
                text += f"{self.image_start_tag}{msg['image']}{self.image_end_tag}\n"
            elif 'text' in msg:
                text += msg['text']
            # Other visual elements (ref/box/quad) not handled yet.
        return text

    def _process_visual_tokens(self, token_ids, process_func):
        """Walk ``token_ids`` and feed each image span / single token to ``process_func``.

        An unterminated image span (start tag with no end tag) runs to the
        end of the sequence.
        """
        result = []
        i = 0
        n = len(token_ids)
        while i < n:
            if token_ids[i] == self.img_start_id:
                # Single .index call; ValueError means no closing tag ahead.
                try:
                    end = token_ids.index(self.img_end_id, i)
                except ValueError:
                    end = n
                result.extend(process_func(token_ids[i:end + 1]))
                i = end + 1
            else:
                result.extend(process_func([token_ids[i]]))
                i += 1
        return result

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """Write the BPE vocabulary back out as base64 token/rank lines.

        Honors the ``filename_prefix`` kwarg that ``save_pretrained`` passes
        (previously it was silently ignored).

        Returns:
            1-tuple with the path of the written vocab file.
        """
        filename_prefix = kwargs.get("filename_prefix")
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
        with open(vocab_file, "w", encoding="utf8") as f:
            for token, rank in self.mergeable_ranks.items():
                f.write(f"{base64.b64encode(token).decode('utf8')} {rank}\n")
        return (vocab_file,)