# Z-Image-Turbo-MLX / tokenizer.py
# Uploaded via huggingface_hub by illusion615 (commit 64566e4, verified).
"""Qwen2 Tokenizer adapter for Z-Image-Turbo.
Uses the `tokenizers` library directly for fast BPE tokenization,
avoiding the slow AutoTokenizer.from_pretrained() initialization.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
logger = logging.getLogger("zimage-mlx")
class Qwen2Tokenizer:
    """Fast BPE tokenizer backed by the `tokenizers` library.

    Loads a serialized ``tokenizer.json`` directly with
    ``tokenizers.Tokenizer.from_file``, which is much faster than
    ``AutoTokenizer.from_pretrained``. The chat template string (if any)
    is read from ``tokenizer_config.json`` but the chat format itself is
    applied manually in :meth:`apply_chat_template`.
    """

    def __init__(self, model_path: Path):
        """Load the tokenizer files from *model_path*.

        Looks for ``tokenizer.json`` (and ``tokenizer_config.json``)
        first under ``model_path / "tokenizer"``, then directly under
        ``model_path``.

        Raises:
            FileNotFoundError: if no ``tokenizer.json`` is found.
        """
        # Local import keeps module import cheap when tokenizer is unused.
        from tokenizers import Tokenizer as HFTokenizer

        tokenizer_path = model_path / "tokenizer"
        json_file = tokenizer_path / "tokenizer.json"
        if not json_file.exists():
            # Fall back to a flat layout with files at the model root.
            json_file = model_path / "tokenizer.json"
        if not json_file.exists():
            raise FileNotFoundError(f"tokenizer.json not found in {model_path}")
        self._tokenizer = HFTokenizer.from_file(str(json_file))

        # Load chat template from tokenizer_config.json if available
        # (best-effort: a missing config simply leaves the template None).
        config_file = tokenizer_path / "tokenizer_config.json"
        if not config_file.exists():
            config_file = model_path / "tokenizer_config.json"
        self._chat_template = None
        if config_file.exists():
            # tokenizer_config.json is UTF-8 by HF convention; be explicit
            # so decoding does not depend on the platform default encoding.
            with open(config_file, encoding="utf-8") as f:
                cfg = json.load(f)
            self._chat_template = cfg.get("chat_template")
        logger.info("[ZImage] Tokenizer loaded: vocab_size=%d", self._tokenizer.get_vocab_size())

    def encode(self, text: str, max_length: int = 512) -> list[int]:
        """Encode *text* to token IDs, truncated to *max_length* tokens."""
        encoded = self._tokenizer.encode(text)
        ids = encoded.ids
        if len(ids) > max_length:
            ids = ids[:max_length]
        return ids

    def apply_chat_template(self, prompt: str, max_length: int = 512) -> dict:
        """Apply Qwen3 chat template format and tokenize.

        Wraps prompt in chat format:
            <|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n

        Token IDs are truncated to *max_length*; the attention mask is
        all ones over the (possibly truncated) sequence.

        Returns:
            dict with ``"input_ids"`` (list[int]) and
            ``"attention_mask"`` (list[int] of the same length).
        """
        # Build chat-formatted text manually (Qwen3 chat template);
        # the template string loaded from config is not rendered here.
        chat_text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        encoded = self._tokenizer.encode(chat_text)
        ids = encoded.ids
        if len(ids) > max_length:
            ids = ids[:max_length]
        attn_mask = [1] * len(ids)
        return {"input_ids": ids, "attention_mask": attn_mask}

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the underlying BPE tokenizer."""
        return self._tokenizer.get_vocab_size()