# Source: FenomAI / VoxCPM2 / tokenization_voxcpm2.py
# Duplicated from openbmb/VoxCPM2 (revision 2f1090b).
"""Custom tokenizer for VoxCPM2 that splits multi-character Chinese tokens.
VoxCPM2 was trained with ``mask_multichar_chinese_tokens`` which splits
multi-character Chinese tokens (e.g. "你好" -> ["你", "好"]) into individual
character IDs before embedding. The base LlamaTokenizerFast produces
multi-character Chinese tokens that the model has never seen during training,
yielding garbled Chinese audio output in downstream inference frameworks.
This module provides ``VoxCPM2Tokenizer`` which transparently applies the
character splitting inside ``encode()`` and ``__call__()``, so any downstream
consumer (vLLM, vLLM-Omni, Nano-vLLM, etc.) gets correct single-character
IDs without code changes.
"""
from transformers import LlamaTokenizerFast
class VoxCPM2Tokenizer(LlamaTokenizerFast):
    """``LlamaTokenizerFast`` subclass that re-splits multi-character Chinese tokens.

    VoxCPM2 was trained with per-character Chinese token ids, so any vocab
    entry consisting of two or more CJK ideographs is transparently expanded
    into the ids of its individual characters inside :meth:`encode` and
    :meth:`__call__`. All other tokens pass through unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Precomputed once at construction: token id -> list of
        # single-character token ids for every splittable vocab entry.
        self._split_map = self._build_split_map()

    def _build_split_map(self) -> dict[int, list[int]]:
        """Build the ``{multi-char token id: [per-character ids]}`` map.

        A token is splittable only when, after stripping the SentencePiece
        word-boundary marker, it has >= 2 characters, every character is a
        CJK ideograph, and every character exists in the vocab on its own.
        """
        vocab = self.get_vocab()
        split_map: dict[int, list[int]] = {}
        for token, tid in vocab.items():
            # Strip the SentencePiece word-boundary marker (U+2581, "▁")
            # before inspecting the characters; the marker is never CJK.
            clean = token.replace("\u2581", "")
            if len(clean) >= 2 and all(self._is_cjk(c) for c in clean):
                char_ids = self.convert_tokens_to_ids(list(clean))
                # Guard against both the UNK id and a None return —
                # convert_tokens_to_ids yields None for unknown tokens when
                # no unk_token is configured, and a None in the map would
                # corrupt expanded id sequences downstream.
                if all(
                    cid is not None and cid != self.unk_token_id
                    for cid in char_ids
                ):
                    split_map[tid] = char_ids
        return split_map

    @staticmethod
    def _is_cjk(c: str) -> bool:
        """Return True if *c* is a CJK ideograph.

        Covers the Unified Ideographs block, Extension A, the Compatibility
        Ideographs block, and Extension B.
        """
        return (
            "\u4e00" <= c <= "\u9fff"  # CJK Unified Ideographs
            or "\u3400" <= c <= "\u4dbf"  # Extension A
            or "\uf900" <= c <= "\ufaff"  # Compatibility Ideographs
            or "\U00020000" <= c <= "\U0002a6df"  # Extension B
        )

    def _expand_ids(self, ids: list[int]) -> list[int]:
        """Return *ids* with every splittable id replaced by its char ids."""
        result: list[int] = []
        for tid in ids:
            # Unmapped ids fall through unchanged.
            result.extend(self._split_map.get(tid, [tid]))
        return result

    def encode(self, text, *args, **kwargs):
        """Encode *text*, then expand multi-character Chinese token ids."""
        return self._expand_ids(super().encode(text, *args, **kwargs))

    def __call__(self, text, *args, **kwargs):
        """Tokenize *text*, expanding Chinese token ids in ``input_ids``.

        NOTE: expansion applies only to plain-list outputs. When callers pass
        ``return_tensors``, ``input_ids`` is a tensor and is returned
        unexpanded (pre-existing limitation — callers needing expansion
        should tensorize afterwards).
        """
        result = super().__call__(text, *args, **kwargs)
        if not hasattr(result, "input_ids"):
            return result
        ids = result["input_ids"]
        if isinstance(ids, list) and ids and isinstance(ids[0], list):
            # Batched input: a list of id lists.
            result["input_ids"] = [self._expand_ids(seq) for seq in ids]
            if "attention_mask" in result:
                # NOTE(review): the mask is rebuilt as all ones to match the
                # new lengths; this assumes the batch was not padded by the
                # super() call — confirm callers pad after expansion.
                result["attention_mask"] = [
                    [1] * len(seq) for seq in result["input_ids"]
                ]
        elif isinstance(ids, list):
            # Single sequence.
            result["input_ids"] = self._expand_ids(ids)
            if "attention_mask" in result:
                result["attention_mask"] = [1] * len(result["input_ids"])
        return result