"""Custom tokenizer for VoxCPM2 that splits multi-character Chinese tokens.

VoxCPM2 was trained with ``mask_multichar_chinese_tokens``, which splits
multi-character Chinese tokens (e.g. "你好" -> ["你", "好"]) into individual
character IDs before embedding. The base ``LlamaTokenizerFast`` produces
multi-character Chinese tokens that the model never saw during training,
which yields garbled Chinese audio in downstream inference frameworks.

This module provides ``VoxCPM2Tokenizer``, which transparently applies the
character splitting inside ``encode()`` and ``__call__()``, so any downstream
consumer (vLLM, vLLM-Omni, Nano-vLLM, etc.) gets correct single-character IDs
without code changes.
"""

from transformers import BatchEncoding, LlamaTokenizerFast


class VoxCPM2Tokenizer(LlamaTokenizerFast):
    """``LlamaTokenizerFast`` that emits one ID per CJK character.

    Every vocabulary token consisting solely of two or more CJK ideographs
    (ignoring the "\u2581" word-boundary marker) is replaced by the IDs of its
    individual characters, provided each character has its own vocabulary
    entry.

    NOTE: splitting happens after the base tokenizer has truncated, so a
    split sequence may be longer than ``max_length``. The Rust ``encodings``
    attached to a returned ``BatchEncoding`` (used by ``word_ids()``,
    ``char_to_token()``, ...) still describe the unsplit tokens.
    """

    # Per-token fields of an encoding that are index-aligned with
    # ``input_ids`` and must therefore be expanded in lockstep with them.
    _ALIGNED_KEYS = (
        "attention_mask",
        "token_type_ids",
        "special_tokens_mask",
        "offset_mapping",
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # token id -> per-character ids, computed once from the vocabulary.
        self._split_map = self._build_split_map()

    def _build_split_map(self) -> dict[int, list[int]]:
        """Map each multi-character CJK token ID to its per-character IDs.

        A token is only split when none of its characters converts to
        ``unk_token_id``, so splitting never introduces unknown tokens.
        """
        vocab = self.get_vocab()
        split_map: dict[int, list[int]] = {}
        for token, tid in vocab.items():
            # Strip the "\u2581" word-boundary marker so "\u2581你好" is treated
            # like "你好". NOTE(review): the marker is not re-emitted in the
            # split; presumably this matches training-time splitting — confirm.
            clean = token.replace("\u2581", "")
            if len(clean) >= 2 and all(self._is_cjk(c) for c in clean):
                char_ids = self.convert_tokens_to_ids(list(clean))
                if all(c != self.unk_token_id for c in char_ids):
                    split_map[tid] = char_ids
        return split_map

    @staticmethod
    def _is_cjk(c: str) -> bool:
        """Return True if ``c`` is a CJK ideograph.

        Covers CJK Unified Ideographs, Extension A, Compatibility Ideographs
        and Extension B.
        """
        return (
            "\u4e00" <= c <= "\u9fff"
            or "\u3400" <= c <= "\u4dbf"
            or "\uf900" <= c <= "\ufaff"
            or "\U00020000" <= c <= "\U0002a6df"
        )

    def _expand_ids(self, ids: list[int]) -> list[int]:
        """Return ``ids`` with every multi-character CJK token split."""
        result: list[int] = []
        for tid in ids:
            result.extend(self._split_map.get(tid, (tid,)))
        return result

    def _expand_row(
        self, ids: list[int], aligned: dict[str, list]
    ) -> tuple[list[int], dict[str, list]]:
        """Split one row of IDs and repeat its aligned per-token values.

        Each aligned value is repeated once per character its token splits
        into. Padding zeros in ``attention_mask`` therefore stay zeros, and
        each character inherits the ``offset_mapping`` span of its source
        token.
        """
        new_ids: list[int] = []
        new_aligned: dict[str, list] = {key: [] for key in aligned}
        for pos, tid in enumerate(ids):
            pieces = self._split_map.get(tid, (tid,))
            new_ids.extend(pieces)
            for key, values in aligned.items():
                new_aligned[key].extend([values[pos]] * len(pieces))
        return new_ids, new_aligned

    @staticmethod
    def _pad_list(values: list, filler, count: int, left: bool) -> list:
        """Return ``values`` with ``count`` copies of ``filler`` on one side."""
        padding = [filler] * count
        return padding + values if left else values + padding

    def _repad(
        self,
        rows: list[list[int]],
        fields: dict[str, list[list]],
        filler_id,
        *,
        padding_side: str,
        pad_to_multiple_of,
    ) -> None:
        """Top up ``rows`` (and aligned ``fields``) in place to equal length.

        The target is the longest row, rounded up to ``pad_to_multiple_of``
        when one is given. Aligned fields are filled with the values the
        fast tokenizer uses for padding positions.
        """
        target = max(len(row) for row in rows)
        if pad_to_multiple_of:
            target = -(-target // pad_to_multiple_of) * pad_to_multiple_of
        fillers = {
            "attention_mask": 0,
            "token_type_ids": self.pad_token_type_id,
            "special_tokens_mask": 1,
            "offset_mapping": (0, 0),
        }
        left = padding_side == "left"
        for r, row in enumerate(rows):
            gap = target - len(row)
            if gap <= 0:
                continue
            rows[r] = self._pad_list(row, filler_id, gap, left)
            for key, values in fields.items():
                values[r] = self._pad_list(values[r], fillers[key], gap, left)

    def _expand_field(
        self,
        result,
        key: str,
        aligned_keys,
        *,
        repad: bool,
        padding_side: str,
        pad_to_multiple_of,
    ) -> None:
        """Split ``result[key]`` and its aligned fields in place.

        Accepts a single sequence (flat list) or a batch (list of lists).
        Does nothing when ``key`` is absent, empty or not a list. When
        ``repad`` is true, rows are re-padded so a padded batch stays
        rectangular after splitting.
        """
        ids = result.get(key)
        if not isinstance(ids, list) or not ids:
            return
        batched = isinstance(ids[0], list)
        keys = [k for k in aligned_keys if k in result]
        rows = ids if batched else [ids]
        fields = {k: (result[k] if batched else [result[k]]) for k in keys}

        new_rows: list[list[int]] = []
        new_fields: dict[str, list[list]] = {k: [] for k in keys}
        for r, row in enumerate(rows):
            row_ids, row_fields = self._expand_row(
                row, {k: fields[k][r] for k in keys}
            )
            new_rows.append(row_ids)
            for k in keys:
                new_fields[k].append(row_fields[k])

        if repad:
            self._repad(
                new_rows,
                new_fields,
                self.pad_token_id,
                padding_side=padding_side,
                pad_to_multiple_of=pad_to_multiple_of,
            )

        if batched:
            result[key] = new_rows
            for k in keys:
                result[k] = new_fields[k]
        else:
            result[key] = new_rows[0]
            for k in keys:
                result[k] = new_fields[k][0]
        if key == "input_ids" and "length" in result:
            # ``return_length=True``: keep the reported lengths truthful.
            result["length"] = [len(row) for row in new_rows]

    def encode(self, text, *args, **kwargs):
        """Encode like the base class, then split multi-character CJK tokens.

        ``return_tensors`` (keyword only) is applied after splitting. As with
        the base class, the result then has a leading batch axis of size 1.
        """
        return_tensors = kwargs.pop("return_tensors", None)
        ids = self._expand_ids(super().encode(text, *args, **kwargs))
        if return_tensors is None:
            return ids
        return BatchEncoding({"input_ids": [ids]}, tensor_type=return_tensors)[
            "input_ids"
        ]

    def __call__(self, text=None, *args, **kwargs):
        """Tokenize like the base class, then split multi-character CJK tokens.

        - ``attention_mask``, ``token_type_ids``, ``special_tokens_mask`` and
          ``offset_mapping`` are expanded in lockstep with ``input_ids``.
        - ``labels`` (from ``text_target``) is split too.
        - A padded batch is re-padded so all rows keep one length.
        - ``return_tensors`` is applied only after splitting; the base class
          would otherwise hand back unsplit tensors.

        ``padding``, ``padding_side``, ``pad_to_multiple_of`` and
        ``return_tensors`` are only recognised as keyword arguments.
        """
        return_tensors = kwargs.pop("return_tensors", None)
        result = super().__call__(text, *args, **kwargs)

        padding = kwargs.get("padding", False)
        options = {
            "repad": padding not in (False, None, "do_not_pad"),
            "padding_side": kwargs.get("padding_side") or self.padding_side,
            "pad_to_multiple_of": kwargs.get("pad_to_multiple_of"),
        }
        self._expand_field(result, "input_ids", self._ALIGNED_KEYS, **options)
        # When both ``text`` and ``text_target`` are given, the target ids
        # live under ``labels``.
        self._expand_field(result, "labels", (), **options)

        if return_tensors is not None:
            ids = result.get("input_ids")
            if isinstance(ids, list) and not (ids and isinstance(ids[0], list)):
                # Single sequence: restore the batch axis the base class keeps
                # for per-token fields when ``return_tensors`` is set.
                for key in ("input_ids", "labels", *self._ALIGNED_KEYS):
                    if key in result:
                        result[key] = [result[key]]
            result.convert_to_tensors(tensor_type=return_tensors)
        return result