print("my_tokenizer.py loaded")
import base64
import logging
import os
import unicodedata
from typing import Dict, List, Tuple, Union

import tiktoken
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "qwen2_5.tiktoken", "ttf": "SimSun.ttf"}


# Special token definitions
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
IMG_START = "<image>"
IMG_END = "</image>"
IMG_PAD = "<imagepad>"
REF_START = "<ref>"
REF_END = "</ref>"
BOX_START = "<box>"
BOX_END = "</box>"
QUAD_START = "<quad>"
QUAD_END = "</quad>"

class Qwen2_5_VLTokenizer(PreTrainedTokenizer):
    """Qwen2.5-VL tokenizer, modified from QWenTokenizer."""

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        image_start_tag=IMG_START,
        image_end_tag=IMG_END,
        image_pad_tag=IMG_PAD,
        ref_start_tag=REF_START,
        ref_end_tag=REF_END,
        box_start_tag=BOX_START,
        box_end_tag=BOX_END,
        quad_start_tag=QUAD_START,
        quad_end_tag=QUAD_END,
        **kwargs,
    ):
        # Initialize special tag attributes
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.image_pad_tag = image_pad_tag
        self.ref_start_tag = ref_start_tag
        self.ref_end_tag = ref_end_tag
        self.box_start_tag = box_start_tag
        self.box_end_tag = box_end_tag
        self.quad_start_tag = quad_start_tag
        self.quad_end_tag = quad_end_tag
        
        # Set of vision-related special tokens
        self.IMAGE_ST = (
            ref_start_tag, ref_end_tag,
            box_start_tag, box_end_tag,
            quad_start_tag, quad_end_tag,
            image_start_tag, image_end_tag,
            image_pad_tag
        )

        super().__init__(**kwargs)
        self.errors = errors

        # Load the BPE vocabulary
        self.mergeable_ranks = self._load_tiktoken_bpe(vocab_file)
        
        # Assign special token ids immediately after the BPE ranks
        self.special_tokens = {
            token: index
            for index, token in enumerate(
                [IMSTART, IMEND] + list(self.IMAGE_ST), 
                start=len(self.mergeable_ranks)
            )
        }

        # Initialize the tiktoken encoder
        self.tokenizer = tiktoken.Encoding(
            "Qwen2.5",
            pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        # Cache frequently used special token ids
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]
        self.img_start_id = self.special_tokens[image_start_tag]
        self.img_end_id = self.special_tokens[image_end_tag]
        self.img_pad_id = self.special_tokens[image_pad_tag]

        # Reverse mapping (id -> token) so id lookups are O(1) instead of
        # scanning the vocabulary lists on every call
        self.decoder = {v: k for k, v in self.mergeable_ranks.items()}
        self.decoder.update({v: k for k, v in self.special_tokens.items()})

    def _load_tiktoken_bpe(self, tiktoken_bpe_file: str) -> Dict[bytes, int]:
        """加载BPE词汇表"""
        with open(tiktoken_bpe_file, "rb") as f:
            contents = f.read()
        return {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }

    def __len__(self) -> int:
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[Union[bytes, str], int]:
        return {**self.mergeable_ranks, **self.special_tokens}

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Token to id转换"""
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError(f"Unknown token: {token}")

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Convert an id back to its token via the reverse vocabulary."""
        if index in self.decoder:
            return self.decoder[index]
        raise ValueError(f"Unknown index: {index}")

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """将token序列转换为字符串"""
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors=self.errors)
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should be bytes or str")
        if temp:
            text += temp.decode("utf-8", errors=self.errors)
        return text

    def tokenize(self, text: str, **kwargs) -> List[Union[bytes, str]]:
        """Tokenize text; special tokens appearing in the text are kept intact."""
        text = unicodedata.normalize("NFC", text)
        return [
            self._convert_id_to_token(i)
            for i in self.tokenizer.encode(text, allowed_special="all")
        ]

    def _decode(self, token_ids: List[int], **kwargs) -> str:
        """解码token ids"""
        skip_special_tokens = kwargs.get("skip_special_tokens", False)
        keep_image_special = kwargs.get("keep_image_special", False)
        
        if skip_special_tokens:
            if keep_image_special:
                token_ids = [i for i in token_ids if i < len(self.mergeable_ranks) or 
                           i in [self.img_start_id, self.img_end_id]]
            else:
                token_ids = [i for i in token_ids if i < len(self.mergeable_ranks)]
        
        return self.tokenizer.decode(token_ids, errors=self.errors)

    def to_list_format(self, text: str) -> List[Dict]:
        """Convert text into list format (multimodal input)."""
        text = unicodedata.normalize("NFC", text)
        # Allow the vision special tokens to appear literally in the input text
        token_ids = self.tokenizer.encode(text, allowed_special=set(self.IMAGE_ST))

        def _encode_element(tokens):
            if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
                return [{'image': self._decode(tokens[1:-1])}]
            # Handling of other visual elements...
            return [{'text': self._decode(tokens)}]

        return self._process_visual_tokens(token_ids, _encode_element)

    def from_list_format(self, messages: List[Dict]) -> str:
        """从列表格式构造多模态文本"""
        text = ""
        for msg in messages:
            if 'image' in msg:
                text += f"{self.image_start_tag}{msg['image']}{self.image_end_tag}\n"
            elif 'text' in msg:
                text += msg['text']
            # Handling of other visual elements...
        return text

    def _process_visual_tokens(self, token_ids, process_func):
        """Split a token sequence into image spans and plain-text runs."""
        result = []
        text_buffer = []
        i = 0
        while i < len(token_ids):
            if token_ids[i] == self.img_start_id:
                # Flush any accumulated plain-text tokens as a single element
                if text_buffer:
                    result.extend(process_func(text_buffer))
                    text_buffer = []
                try:
                    end = token_ids.index(self.img_end_id, i)
                except ValueError:
                    end = len(token_ids) - 1
                result.extend(process_func(token_ids[i:end + 1]))
                i = end + 1
            else:
                text_buffer.append(token_ids[i])
                i += 1
        if text_buffer:
            result.extend(process_func(text_buffer))
        return result

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """保存词汇表"""
        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
        with open(vocab_file, "w", encoding="utf8") as f:
            for token, rank in self.mergeable_ranks.items():
                f.write(f"{base64.b64encode(token).decode('utf8')} {rank}\n")
        return (vocab_file,)
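

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch. The vocab path "qwen2_5.tiktoken" and the
# image path "demo.jpg" are placeholders, not files shipped with this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tok = Qwen2_5_VLTokenizer(vocab_file="qwen2_5.tiktoken")

    # Plain-text round trip.
    tokens = tok.tokenize("Hello, world!")
    print(tokens)
    print(tok.convert_tokens_to_string(tokens))

    # Build a multimodal prompt from list format, then parse it back.
    prompt = tok.from_list_format([
        {"image": "demo.jpg"},  # placeholder image path
        {"text": "Describe this image."},
    ])
    print(prompt)
    print(tok.to_list_format(prompt))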