File size: 4,336 Bytes
9bd9fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Fast-loading Qwen3-VL tokenizer with 131k visual tokens.

Visual tokens live in model.vocab (fast BPE hash-map load) rather than
added_tokens (slow Aho-Corasick build).  A regex pre-split in the Python
wrapper ensures encode/call with visual token text produces single IDs.

Strategy: replace each <|visual token XXXXXX|> with a NUL byte (\x00)
before sending to the Rust backend, then swap the NUL-byte token ID (188)
with the real visual-token ID in the output.
"""

import re
from typing import List, Optional, Union

from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast

# Matches "<|visual token NNNNNN|>"; group(1) is the six-digit token index.
_VISUAL_RE = re.compile(r"<\|visual token (\d{6})\|>")
# Vocab ID of visual token 000000; token NNNNNN maps to START_ID + NNNNNN.
_VISUAL_TOKEN_START_ID = 151674
# NUL byte stands in for each visual token before the Rust backend tokenizes.
_PLACEHOLDER_CHAR = "\x00"
# Token ID the backend emits for the NUL byte (see module docstring); swapped
# back to the real visual-token ID after encoding.
_PLACEHOLDER_TOKEN_ID = 188


class Qwen3VLVisualTokenizerFast(Qwen2TokenizerFast):
    """Qwen2 fast tokenizer that resolves ``<|visual token NNNNNN|>`` markers.

    Visual-token text is replaced by a NUL placeholder before the Rust
    backend runs; the placeholder IDs in the output are then swapped for
    the real visual-token IDs (see the module docstring for why).
    """

    # ---------- public encode() ----------
    def encode(self, text, text_pair=None, add_special_tokens=True, **kwargs):
        """Encode *text* (and optional *text_pair*), translating visual tokens.

        Bug fix: markers appearing only in ``text_pair`` are now substituted
        too — the previous check looked at ``text`` alone, so pair-only
        visual tokens leaked through as raw text.
        """
        text_has = isinstance(text, str) and _VISUAL_RE.search(text) is not None
        pair_has = isinstance(text_pair, str) and _VISUAL_RE.search(text_pair) is not None
        if not (text_has or pair_has):
            return super().encode(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)

        vids: list[int] = []
        if text_has:
            text, tvids = _replace_visual(text)
            vids.extend(tvids)
        if pair_has:
            text_pair, pvids = _replace_visual(text_pair)
            vids.extend(pvids)

        ids = super().encode(
            text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )
        # Placeholders come out in marker order: text segment first, then pair.
        _swap_ids(ids, vids)
        return ids

    # ---------- batch path (powers __call__) ----------
    def _batch_encode_plus(self, batch_text_or_text_pairs, **kwargs):
        """Batch encode, substituting visual tokens per batch item."""
        if not any(_text_has_visual(item) for item in batch_text_or_text_pairs):
            # Fast path: nothing to rewrite, defer entirely to the backend.
            return super()._batch_encode_plus(batch_text_or_text_pairs, **kwargs)

        replaced_batch = []
        all_vids: list[list[int]] = []

        for item in batch_text_or_text_pairs:
            if isinstance(item, (tuple, list)):
                text, pair = item[0], (item[1] if len(item) > 1 else None)
            else:
                text, pair = item, None

            vids: list[int] = []
            if isinstance(text, str) and _VISUAL_RE.search(text):
                text, tvids = _replace_visual(text)
                vids.extend(tvids)
            if pair is not None and isinstance(pair, str) and _VISUAL_RE.search(pair):
                pair, pvids = _replace_visual(pair)
                vids.extend(pvids)

            replaced_batch.append((text, pair) if pair is not None else text)
            all_vids.append(vids)

        result = super()._batch_encode_plus(replaced_batch, **kwargs)

        for i, vids in enumerate(all_vids):
            if not vids:
                continue
            ids = result["input_ids"][i]
            if hasattr(ids, "tolist"):
                # return_tensors was requested: ids is a framework array.
                plain = ids.tolist()
                _swap_ids(plain, vids)
                if type(ids).__module__.split(".")[0] == "torch":
                    import torch

                    result["input_ids"][i] = torch.tensor(
                        plain, dtype=ids.dtype, device=ids.device
                    )
                else:
                    # Bug fix: the old code rebuilt every array with
                    # torch.tensor(..., dtype=ids.dtype), which raises for
                    # numpy arrays (numpy dtypes are not torch dtypes).
                    import numpy as np

                    result["input_ids"][i] = np.asarray(plain, dtype=ids.dtype)
            else:
                _swap_ids(ids, vids)
                result["input_ids"][i] = ids

        return result


def _text_has_visual(item) -> bool:
    """Return True if a batch item contains any visual-token marker.

    Bug fix: for a ``(text, pair)`` tuple the old code inspected only the
    first element, so items whose visual tokens lived exclusively in the
    pair text made the batch path skip substitution entirely.
    """
    if isinstance(item, (tuple, list)):
        # Check both segments (text and, when present, the pair).
        return any(isinstance(t, str) and _VISUAL_RE.search(t) for t in item[:2])
    return isinstance(item, str) and _VISUAL_RE.search(item) is not None


def _replace_visual(text: str):
    """Substitute each visual-token marker with the NUL placeholder.

    Returns ``(substituted_text, ids)`` where *ids* lists the real
    visual-token vocab IDs in the order their markers appeared in *text*.
    """
    collected: list[int] = []

    def _to_placeholder(match: "re.Match") -> str:
        # group(1) is the six-digit index; offset it into the visual range.
        collected.append(_VISUAL_TOKEN_START_ID + int(match.group(1)))
        return _PLACEHOLDER_CHAR

    substituted = _VISUAL_RE.sub(_to_placeholder, text)
    return substituted, collected


def _swap_ids(ids: list, vids: list[int]) -> None:
    """Overwrite placeholder token IDs in *ids* with the real visual IDs.

    Mutates *ids* in place; substitution stops once *vids* is exhausted,
    leaving any surplus placeholders untouched.
    """
    pending = iter(vids)
    for pos, token_id in enumerate(ids):
        if token_id != _PLACEHOLDER_TOKEN_ID:
            continue
        replacement = next(pending, None)
        if replacement is None:
            break
        ids[pos] = replacement