"""Tokenizer module for Qwen2.5-VL: a tiktoken-based BPE tokenizer extended
with chat and image/grounding special tokens."""
# NOTE(review): removed the import-time debug print ("my_tokenizer.py loaded");
# module import should be silent — use logging if load tracing is needed.

import base64
import logging
import os
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional

import requests
import tiktoken
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizer, AddedToken
from transformers.utils import try_to_load_from_cache

logger = logging.getLogger(__name__)

# Filenames the tokenizer expects alongside the checkpoint (HF convention).
VOCAB_FILES_NAMES = {"vocab_file": "qwen2_5.tiktoken", "ttf": "SimSun.ttf"}
| |
|
| |
|
| | |
# Chat-turn delimiters (assigned the first two special-token ids in
# Qwen2_5_VLTokenizer.special_tokens).
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# Image markers: content between IMG_START/IMG_END is treated as an image
# reference by to_list_format; IMG_PAD presumably fills visual-token
# positions — TODO confirm against the model's processor.
IMG_START = "<image>"
IMG_END = "</image>"
IMG_PAD = "<imagepad>"
# Grounding markers: <ref> appears to wrap a referring expression and
# <box>/<quad> its coordinates — NOTE(review): usage not shown in this file.
REF_START = "<ref>"
REF_END = "</ref>"
BOX_START = "<box>"
BOX_END = "</box>"
QUAD_START = "<quad>"
QUAD_END = "</quad>"
| |
|
class Qwen2_5_VLTokenizer(PreTrainedTokenizer):
    """Qwen2.5-VL tokenizer, modified from QWenTokenizer.

    Wraps a tiktoken BPE encoding loaded from a base64-encoded ``.tiktoken``
    vocabulary file and appends chat/image special tokens directly after the
    BPE ranks, so special-token ids occupy
    ``len(mergeable_ranks) .. n_vocab - 1``.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        image_start_tag=IMG_START,
        image_end_tag=IMG_END,
        image_pad_tag=IMG_PAD,
        ref_start_tag=REF_START,
        ref_end_tag=REF_END,
        box_start_tag=BOX_START,
        box_end_tag=BOX_END,
        quad_start_tag=QUAD_START,
        quad_end_tag=QUAD_END,
        **kwargs,
    ):
        """Build the tokenizer.

        Args:
            vocab_file: Path to the base64-encoded tiktoken BPE vocab file.
            errors: ``bytes.decode`` error policy used when detokenizing.
            image_start_tag..quad_end_tag: Override strings for the
                image/grounding special tokens.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.image_pad_tag = image_pad_tag
        self.ref_start_tag = ref_start_tag
        self.ref_end_tag = ref_end_tag
        self.box_start_tag = box_start_tag
        self.box_end_tag = box_end_tag
        self.quad_start_tag = quad_start_tag
        self.quad_end_tag = quad_end_tag

        # All image/grounding tags, in the order their ids are assigned
        # (after IMSTART / IMEND below).
        self.IMAGE_ST = (
            ref_start_tag, ref_end_tag,
            box_start_tag, box_end_tag,
            quad_start_tag, quad_end_tag,
            image_start_tag, image_end_tag,
            image_pad_tag,
        )

        super().__init__(**kwargs)
        self.errors = errors

        self.mergeable_ranks = self._load_tiktoken_bpe(vocab_file)

        # Special-token ids start immediately after the BPE ranks.
        self.special_tokens = {
            token: index
            for index, token in enumerate(
                [IMSTART, IMEND] + list(self.IMAGE_ST),
                start=len(self.mergeable_ranks),
            )
        }

        # Precomputed id -> token reverse map so _convert_id_to_token is O(1)
        # instead of a linear scan over dict values on every call. Ids never
        # collide: special ids start at len(self.mergeable_ranks).
        self.decoder: Dict[int, Union[bytes, str]] = {
            rank: token for token, rank in self.mergeable_ranks.items()
        }
        self.decoder.update(
            {index: token for token, index in self.special_tokens.items()}
        )

        self.tokenizer = tiktoken.Encoding(
            "Qwen2.5",
            pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        # Frequently-used special-token ids, cached as attributes.
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]
        self.img_start_id = self.special_tokens[image_start_tag]
        self.img_end_id = self.special_tokens[image_end_tag]
        self.img_pad_id = self.special_tokens[image_pad_tag]

    def _load_tiktoken_bpe(self, tiktoken_bpe_file: str) -> Dict[bytes, int]:
        """Load the BPE vocabulary: one ``<base64-token> <rank>`` pair per line."""
        with open(tiktoken_bpe_file, "rb") as f:
            contents = f.read()
        return {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }

    def __len__(self) -> int:
        # Total vocabulary size: BPE ranks plus special tokens.
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        """Return the full token -> id mapping (bytes keys for BPE tokens,
        str keys for special tokens)."""
        return {**self.mergeable_ranks, **self.special_tokens}

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Convert a single token (str special token or bytes BPE token) to its id.

        Raises:
            ValueError: If the token is in neither table.
        """
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError(f"Unknown token: {token}")

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Convert an id back to its token.

        Uses the reverse map built in ``__init__`` (O(1)) rather than the
        original linear scan over the vocab dicts' values.

        Raises:
            ValueError: If the id is not in the vocabulary.
        """
        try:
            return self.decoder[index]
        except KeyError:
            raise ValueError(f"Unknown index: {index}") from None

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """Join a mixed bytes/str token sequence into text.

        Consecutive bytes tokens are buffered and decoded together so that
        multi-byte UTF-8 sequences split across tokens decode correctly.

        Raises:
            TypeError: If a token is neither bytes nor str.
        """
        text = ""
        pending = b""
        for tok in tokens:
            if isinstance(tok, str):
                if pending:
                    text += pending.decode("utf-8", errors=self.errors)
                    pending = b""
                text += tok
            elif isinstance(tok, bytes):
                pending += tok
            else:
                raise TypeError("token should be bytes or str")
        if pending:
            text += pending.decode("utf-8", errors=self.errors)
        return text

    def tokenize(self, text: str, **kwargs) -> List[Union[bytes, str]]:
        """NFC-normalize *text* and return its tokens.

        ``allowed_special`` may be passed through to tiktoken; the default
        (empty set) matches tiktoken's own default, so special-token text in
        the input still raises, preserving the original behavior.
        """
        text = unicodedata.normalize("NFC", text)
        allowed_special = kwargs.get("allowed_special", set())
        return [
            self._convert_id_to_token(idx)
            for idx in self.tokenizer.encode(text, allowed_special=allowed_special)
        ]

    def _decode(self, token_ids: List[int], **kwargs) -> str:
        """Decode token ids to text.

        Kwargs:
            skip_special_tokens: Drop all special-token ids before decoding.
            keep_image_special: With ``skip_special_tokens``, keep the image
                start/end ids so image references survive decoding.
        """
        skip_special_tokens = kwargs.get("skip_special_tokens", False)
        keep_image_special = kwargs.get("keep_image_special", False)

        if skip_special_tokens:
            # Special ids are exactly those >= len(self.mergeable_ranks).
            if keep_image_special:
                kept = {self.img_start_id, self.img_end_id}
                token_ids = [
                    i for i in token_ids
                    if i < len(self.mergeable_ranks) or i in kept
                ]
            else:
                token_ids = [i for i in token_ids if i < len(self.mergeable_ranks)]

        return self.tokenizer.decode(token_ids, errors=self.errors)

    def to_list_format(self, text: str) -> List[Dict]:
        """Convert text into a list of ``{'text': ...}`` / ``{'image': ...}`` dicts.

        Fix: special tokens must be explicitly allowed when encoding here.
        With tiktoken's default ``disallowed_special="all"``, any input that
        actually contained an image tag raised ``ValueError``, making this
        method unusable for its intended purpose.
        """
        text = unicodedata.normalize("NFC", text)
        token_ids = self.tokenizer.encode(
            text, allowed_special=set(self.special_tokens)
        )

        def _encode_element(tokens):
            # A complete <image>...</image> span becomes an image element;
            # everything else is decoded as plain text.
            if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
                return [{'image': self._decode(tokens[1:-1])}]
            return [{'text': self._decode(tokens)}]

        return self._process_visual_tokens(token_ids, _encode_element)

    def from_list_format(self, messages: List[Dict]) -> str:
        """Construct multimodal text from a list of message dicts.

        An 'image' entry is wrapped in the image tags; a 'text' entry is
        appended verbatim. Entries with neither key are ignored.
        """
        text = ""
        for msg in messages:
            if 'image' in msg:
                text += f"{self.image_start_tag}{msg['image']}{self.image_end_tag}\n"
            elif 'text' in msg:
                text += msg['text']

        return text

    def _process_visual_tokens(self, token_ids, process_func):
        """Walk *token_ids*, feeding image spans and single ids to *process_func*.

        An unterminated image span (start tag with no end tag) runs to the end
        of the sequence, matching the original behavior.
        """
        result = []
        i = 0
        n = len(token_ids)
        while i < n:
            if token_ids[i] == self.img_start_id:
                # Single EAFP scan for the end tag instead of the original
                # `in` check plus `.index` (which scanned the tail twice).
                try:
                    end = token_ids.index(self.img_end_id, i)
                except ValueError:
                    end = n
                result.extend(process_func(token_ids[i:end + 1]))
                i = end + 1
            else:
                result.extend(process_func([token_ids[i]]))
                i += 1
        return result

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """Write the BPE vocabulary to *save_directory* and return its path.

        Supports the standard Hugging Face ``filename_prefix`` kwarg
        (default ``None`` keeps the original filename) and takes the base
        name from ``VOCAB_FILES_NAMES`` for consistency with the class attr.
        """
        filename_prefix = kwargs.get("filename_prefix")
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(
            save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"]
        )
        with open(vocab_file, "w", encoding="utf8") as f:
            for token, rank in self.mergeable_ranks.items():
                f.write(f"{base64.b64encode(token).decode('utf8')} {rank}\n")
        return (vocab_file,)