# qwen_weitiao / my_tokenizer.py
# Last commit by yujingfeng: "Update my_tokenizer.py" (9639275, verified)
# NOTE(review): this prints to stdout every time the module is imported; consider
# removing it (or switching to the module logger) once debugging is finished.
print("my_tokenizer.py loaded")
import base64
import logging
import os
import requests
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
import tiktoken
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizer, AddedToken
from transformers.utils import try_to_load_from_cache
logger = logging.getLogger(__name__)

# File names declared through `vocab_files_names` on the tokenizer class below.
VOCAB_FILES_NAMES = dict(vocab_file="qwen2_5.tiktoken", ttf="SimSun.ttf")

# Special-token strings (updated set).
# Chat-turn delimiters.
IMSTART, IMEND = "<|im_start|>", "<|im_end|>"
# Image span markers and the image padding token.
IMG_START, IMG_END, IMG_PAD = "<image>", "</image>", "<imagepad>"
# Reference / box / quad span markers.
REF_START, REF_END = "<ref>", "</ref>"
BOX_START, BOX_END = "<box>", "</box>"
QUAD_START, QUAD_END = "<quad>", "</quad>"
class Qwen2_5_VLTokenizer(PreTrainedTokenizer):
    """Qwen2.5-VL tokenizer, modified from QWenTokenizer.

    Byte-level BPE is delegated to a ``tiktoken.Encoding`` built from a
    base64 rank file. The chat markers (``IMSTART``/``IMEND``) and the vision
    tags configured in ``__init__`` are registered as special tokens whose ids
    follow directly after the BPE ranks.

    Tokens are represented as ``bytes`` for BPE pieces and ``str`` for special
    tokens.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        image_start_tag=IMG_START,
        image_end_tag=IMG_END,
        image_pad_tag=IMG_PAD,
        ref_start_tag=REF_START,
        ref_end_tag=REF_END,
        box_start_tag=BOX_START,
        box_end_tag=BOX_END,
        quad_start_tag=QUAD_START,
        quad_end_tag=QUAD_END,
        **kwargs,
    ):
        """Load the BPE ranks and build the tiktoken encoder.

        Args:
            vocab_file: Path to the tiktoken rank file; each non-empty line is
                ``<base64 token> <rank>``.
            errors: Error handler used when decoding bytes as UTF-8.
            image_start_tag, image_end_tag, image_pad_tag, ref_start_tag,
            ref_end_tag, box_start_tag, box_end_tag, quad_start_tag,
            quad_end_tag: Strings registered as special tokens.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # Special-tag strings.
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.image_pad_tag = image_pad_tag
        self.ref_start_tag = ref_start_tag
        self.ref_end_tag = ref_end_tag
        self.box_start_tag = box_start_tag
        self.box_end_tag = box_end_tag
        self.quad_start_tag = quad_start_tag
        self.quad_end_tag = quad_end_tag
        # Vision-related special tokens, in the order their ids are assigned
        # (right after IMSTART and IMEND).
        self.IMAGE_ST = (
            ref_start_tag, ref_end_tag,
            box_start_tag, box_end_tag,
            quad_start_tag, quad_end_tag,
            image_start_tag, image_end_tag,
            image_pad_tag,
        )
        self.errors = errors

        # BPE vocabulary: token bytes -> rank.
        self.mergeable_ranks = self._load_tiktoken_bpe(vocab_file)
        # Special tokens get consecutive ids starting right after the BPE ranks.
        self.special_tokens = {
            token: index
            for index, token in enumerate(
                [IMSTART, IMEND] + list(self.IMAGE_ST),
                start=len(self.mergeable_ranks),
            )
        }
        self.tokenizer = tiktoken.Encoding(
            "Qwen2.5",
            pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        # Reverse lookup (id -> token), so that _convert_id_to_token is O(1)
        # instead of materializing and scanning the whole vocabulary per call.
        # Special tokens are inserted last so they take precedence, matching
        # the lookup order of the previous implementation.
        self.decoder: Dict[int, Union[bytes, str]] = {
            rank: token for token, rank in self.mergeable_ranks.items()
        }
        self.decoder.update(
            {index: token for token, index in self.special_tokens.items()}
        )

        # Frequently used special-token ids.
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]
        self.img_start_id = self.special_tokens[image_start_tag]
        self.img_end_id = self.special_tokens[image_end_tag]
        self.img_pad_id = self.special_tokens[image_pad_tag]

        # Called last on purpose: newer transformers releases (>= 4.34) call
        # get_vocab()/len(self) from PreTrainedTokenizer.__init__, and those
        # require self.mergeable_ranks / self.tokenizer to exist already.
        super().__init__(**kwargs)

    def _load_tiktoken_bpe(self, tiktoken_bpe_file: str) -> Dict[bytes, int]:
        """Read a tiktoken rank file into a ``{token_bytes: rank}`` dict.

        Each non-empty line holds a base64-encoded token and its integer rank,
        separated by whitespace.
        """
        with open(tiktoken_bpe_file, "rb") as f:
            contents = f.read()
        return {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }

    @property
    def vocab_size(self) -> int:
        """Number of ids known to the tiktoken encoding (BPE + special tokens)."""
        return self.tokenizer.n_vocab

    def __len__(self) -> int:
        """Return the size of the tiktoken vocabulary (BPE + special tokens)."""
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[Union[bytes, str], int]:
        """Return a merged mapping of BPE tokens (bytes) and special tokens (str) to ids."""
        return {**self.mergeable_ranks, **self.special_tokens}

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Convert a token (special ``str`` or BPE ``bytes``) to its id.

        Raises:
            ValueError: If the token is neither a special token nor a BPE token.
        """
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError(f"Unknown token: {token}")

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Convert an id to its token (``str`` for special tokens, else ``bytes``).

        Raises:
            ValueError: If the id is not in the vocabulary.
        """
        try:
            return self.decoder[index]
        except KeyError:
            raise ValueError(f"Unknown index: {index}") from None

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """Join a token sequence into text.

        Consecutive ``bytes`` tokens are buffered and decoded together as
        UTF-8 (using ``self.errors``). A multi-byte character split across BPE
        pieces therefore decodes correctly. ``str`` tokens are emitted verbatim.

        Raises:
            TypeError: If a token is neither ``bytes`` nor ``str``.
        """
        parts: List[str] = []
        pending = bytearray()
        for t in tokens:
            if isinstance(t, str):
                if pending:
                    parts.append(pending.decode("utf-8", errors=self.errors))
                    pending.clear()
                parts.append(t)
            elif isinstance(t, bytes):
                pending.extend(t)
            else:
                raise TypeError("token should be bytes or str")
        if pending:
            parts.append(pending.decode("utf-8", errors=self.errors))
        return "".join(parts)

    def tokenize(
        self,
        text: str,
        *,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """NFC-normalize ``text`` and split it into tokens.

        Args:
            text: Input text.
            allowed_special: Special-token strings that are encoded as special
                tokens when they appear in ``text`` (``"all"`` by default).
                Pass ``set()`` when tokenizing untrusted text that must not be
                able to inject chat/vision markers.
            disallowed_special: Special-token strings that raise if found in
                ``text``. The default ``()`` means none raise. With
                ``allowed_special=set()`` such text is then encoded as plain
                bytes.
            **kwargs: Accepted for ``PreTrainedTokenizer`` compatibility; ignored.

        Returns:
            ``str`` for special tokens, ``bytes`` for BPE pieces.
        """
        text = unicodedata.normalize("NFC", text)
        # tiktoken's own defaults (allowed_special=set(), disallowed_special="all")
        # would raise on any text containing e.g. "<|im_start|>" or "<image>".
        token_ids = self.tokenizer.encode(
            text,
            allowed_special=allowed_special,
            disallowed_special=disallowed_special,
        )
        return [self._convert_id_to_token(i) for i in token_ids]

    def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
        """Decode token ids to text.

        Keyword Args:
            skip_special_tokens: Drop every id outside the BPE range.
            keep_image_special: With ``skip_special_tokens``, still keep the
                image start/end ids.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        skip_special_tokens = kwargs.get("skip_special_tokens", False)
        keep_image_special = kwargs.get("keep_image_special", False)
        if skip_special_tokens:
            # BPE ranks are assumed to be 0..len-1; every special id is >= len.
            n_regular = len(self.mergeable_ranks)
            if keep_image_special:
                kept = {self.img_start_id, self.img_end_id}
                token_ids = [i for i in token_ids if i < n_regular or i in kept]
            else:
                token_ids = [i for i in token_ids if i < n_regular]
        return self.tokenizer.decode(token_ids, errors=self.errors)

    def to_list_format(self, text: str) -> List[Dict]:
        """Split multimodal text into ``[{'image': ...}, {'text': ...}, ...]``.

        An ``image_start_tag ... image_end_tag`` span becomes ``{'image': <inner
        text>}``. Each maximal run of other tokens becomes one ``{'text': ...}``.
        """
        text = unicodedata.normalize("NFC", text)
        # The input is expected to contain the image tags themselves, so special
        # tokens must be allowed; tiktoken's default would raise on them.
        token_ids = self.tokenizer.encode(text, allowed_special="all")

        def _encode_element(tokens):
            if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
                return [{'image': self._decode(tokens[1:-1])}]
            # TODO: handle the other vision elements (ref / box / quad).
            return [{'text': self._decode(tokens)}]

        return self._process_visual_tokens(token_ids, _encode_element)

    def from_list_format(self, messages: List[Dict]) -> str:
        """Build multimodal text from ``[{'image': ...} | {'text': ...}, ...]``.

        Images are wrapped in the image start/end tags and followed by a
        newline. Text is appended verbatim. Other entries are ignored.
        """
        parts: List[str] = []
        for msg in messages:
            if 'image' in msg:
                parts.append(f"{self.image_start_tag}{msg['image']}{self.image_end_tag}\n")
            elif 'text' in msg:
                parts.append(msg['text'])
            # TODO: handle the other vision elements (ref / box / quad).
        return "".join(parts)

    def _process_visual_tokens(self, token_ids, process_func):
        """Split ids into image spans and text runs and map ``process_func`` over them.

        An image span runs from an ``img_start_id`` through the next
        ``img_end_id`` inclusive, or to the end if no end tag follows. Every
        maximal run of other ids is passed as a single chunk, so multi-byte
        characters split across BPE tokens decode correctly and text is not
        fragmented into one element per token. ``process_func`` must return a
        list; the results are concatenated.
        """
        result = []
        text_run = []
        n = len(token_ids)
        i = 0
        while i < n:
            if token_ids[i] == self.img_start_id:
                if text_run:
                    result.extend(process_func(text_run))
                    text_run = []
                try:
                    end = token_ids.index(self.img_end_id, i)
                except ValueError:
                    end = n - 1  # unterminated image span: take the rest
                result.extend(process_func(token_ids[i:end + 1]))
                i = end + 1
            else:
                text_run.append(token_ids[i])
                i += 1
        if text_run:
            result.extend(process_func(text_run))
        return result

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None, **kwargs
    ) -> Tuple[str]:
        """Write the BPE ranks to ``save_directory`` in tiktoken format.

        Args:
            save_directory: Existing directory to write into.
            filename_prefix: Optional prefix, joined to the file name with
                ``-`` (the ``PreTrainedTokenizer.save_vocabulary`` convention).

        Returns:
            A 1-tuple with the path of the written file.
        """
        file_name = self.vocab_files_names["vocab_file"]
        if filename_prefix:
            file_name = f"{filename_prefix}-{file_name}"
        vocab_file = os.path.join(save_directory, file_name)
        with open(vocab_file, "w", encoding="utf8") as f:
            f.writelines(
                f"{base64.b64encode(token).decode('utf8')} {rank}\n"
                for token, rank in self.mergeable_ranks.items()
            )
        return (vocab_file,)