"""Custom tokenizer for VoxCPM2 that splits multi-character Chinese tokens.

VoxCPM2 was trained with ``mask_multichar_chinese_tokens``, which splits
multi-character Chinese tokens (e.g. "你好" -> ["你", "好"]) into individual
character IDs before embedding.  The base LlamaTokenizerFast produces
multi-character Chinese tokens that the model has never seen during training,
yielding garbled Chinese audio output in downstream inference frameworks.

This module provides ``VoxCPM2Tokenizer`` which transparently applies the
character splitting inside ``encode()`` and ``__call__()``, so any downstream
consumer (vLLM, vLLM-Omni, Nano-vLLM, etc.) gets correct single-character
IDs without code changes.
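
Example (a sketch; ``path/to/voxcpm2`` is a placeholder for a local
checkpoint directory)::

    >>> tok = VoxCPM2Tokenizer.from_pretrained("path/to/voxcpm2")
    >>> tok.encode("你好", add_special_tokens=False)  # two single-character IDs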
"""

from transformers import LlamaTokenizerFast


class VoxCPM2Tokenizer(LlamaTokenizerFast):
    """Drop-in replacement for ``LlamaTokenizerFast`` that expands every
    multi-character Chinese token into its single-character token IDs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Precompute once: token ID -> list of single-character token IDs.
        self._split_map = self._build_split_map()

    def _build_split_map(self) -> dict[int, list[int]]:
        """Map each multi-character CJK token ID to its per-character IDs."""
        vocab = self.get_vocab()
        split_map: dict[int, list[int]] = {}
        for token, tid in vocab.items():
            # Strip the SentencePiece word-boundary marker (U+2581, "▁").
            clean = token.replace("\u2581", "")
            if len(clean) >= 2 and all(self._is_cjk(c) for c in clean):
                char_ids = self.convert_tokens_to_ids(list(clean))
                # Only split when every character exists in the vocabulary.
                if all(c != self.unk_token_id for c in char_ids):
                    split_map[tid] = char_ids
        return split_map

    @staticmethod
    def _is_cjk(c: str) -> bool:
        """Return True if ``c`` is a CJK ideograph."""
        return (
            "\u4e00" <= c <= "\u9fff"  # CJK Unified Ideographs
            or "\u3400" <= c <= "\u4dbf"  # CJK Extension A
            or "\uf900" <= c <= "\ufaff"  # CJK Compatibility Ideographs
            or "\U00020000" <= c <= "\U0002a6df"  # CJK Extension B
        )

    def _expand_ids(self, ids: list[int]) -> list[int]:
        """Replace each splittable token ID with its single-character IDs."""
        result: list[int] = []
        for tid in ids:
            expansion = self._split_map.get(tid)
            if expansion is not None:
                result.extend(expansion)
            else:
                result.append(tid)
        return result

    def encode(self, text, *args, **kwargs):
        """Encode ``text``, then expand multi-character Chinese token IDs."""
        ids = super().encode(text, *args, **kwargs)
        return self._expand_ids(ids)

    def __call__(self, text, *args, **kwargs):
        result = super().__call__(text, *args, **kwargs)
        if hasattr(result, "input_ids"):
            ids = result["input_ids"]
            # Note: tensor outputs (``return_tensors=...``) fall through both
            # branches below and are returned unexpanded.
            if isinstance(ids, list) and ids and isinstance(ids[0], list):
                # Batched input: expand each sequence independently.
                result["input_ids"] = [self._expand_ids(x) for x in ids]
                if "attention_mask" in result:
                    # Rebuild masks to match the new lengths (assumes no padding).
                    result["attention_mask"] = [
                        [1] * len(x) for x in result["input_ids"]
                    ]
            elif isinstance(ids, list):
                # Single unbatched sequence.
                result["input_ids"] = self._expand_ids(ids)
                if "attention_mask" in result:
                    result["attention_mask"] = [1] * len(result["input_ids"])
        return result
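

if __name__ == "__main__":
    # Minimal smoke test (a sketch; "path/to/voxcpm2" is a hypothetical local
    # checkpoint directory containing this tokenizer's vocabulary files).
    tok = VoxCPM2Tokenizer.from_pretrained("path/to/voxcpm2")
    ids = tok.encode("你好，世界", add_special_tokens=False)
    print(tok.convert_ids_to_tokens(ids))
    # No multi-character Chinese token should survive expansion, i.e. no
    # returned ID may still be a key of the split map.
    assert not any(tid in tok._split_map for tid in ids)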