File size: 2,489 Bytes
64566e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Qwen2 Tokenizer adapter for Z-Image-Turbo.

Uses the `tokenizers` library directly for fast BPE tokenization,
avoiding the slow AutoTokenizer.from_pretrained() initialization.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path

logger = logging.getLogger("zimage-mlx")


class Qwen2Tokenizer:
    """Fast BPE tokenizer backed by the ``tokenizers`` library.

    Files are looked up first in ``<model_path>/tokenizer/`` and then in
    ``<model_path>/`` itself.
    """

    def __init__(self, model_path: Path | str):
        """Load ``tokenizer.json`` (and optionally ``tokenizer_config.json``).

        Args:
            model_path: Model root directory; a ``str`` is converted to ``Path``.

        Raises:
            FileNotFoundError: If ``tokenizer.json`` exists in neither
                ``<model_path>/tokenizer/`` nor ``<model_path>/``.
            ImportError: If the ``tokenizers`` package is not installed.
        """
        model_path = Path(model_path)
        tokenizer_path = model_path / "tokenizer"

        json_file = self._find_file("tokenizer.json", tokenizer_path, model_path)
        if json_file is None:
            raise FileNotFoundError(
                f"tokenizer.json not found in {tokenizer_path} or {model_path}"
            )

        # Deferred import: only needed once a tokenizer file has been located,
        # so a wrong path surfaces the actionable FileNotFoundError first.
        from tokenizers import Tokenizer as HFTokenizer

        self._tokenizer = HFTokenizer.from_file(str(json_file))

        # Chat template from tokenizer_config.json, if present. It is stored
        # only; apply_chat_template() uses a hard-coded format instead.
        self._chat_template = None
        config_file = self._find_file(
            "tokenizer_config.json", tokenizer_path, model_path
        )
        if config_file is not None:
            # Explicit UTF-8: templates may contain non-ASCII characters and the
            # platform default encoding is not UTF-8 everywhere (e.g. Windows).
            with open(config_file, encoding="utf-8") as f:
                cfg = json.load(f)
            self._chat_template = cfg.get("chat_template")

        logger.info("[ZImage] Tokenizer loaded: vocab_size=%d", self._tokenizer.get_vocab_size())

    @staticmethod
    def _find_file(name: str, *dirs: Path) -> Path | None:
        """Return the first ``directory / name`` that exists, else ``None``."""
        for directory in dirs:
            candidate = directory / name
            if candidate.exists():
                return candidate
        return None

    @staticmethod
    def _truncate(ids: list[int], max_length: int | None) -> list[int]:
        """Keep at most *max_length* leading ids; ``None`` keeps all of them.

        Raises:
            ValueError: If *max_length* is negative (slicing with a negative
                bound would silently drop tokens from the end instead).
        """
        if max_length is None:
            return list(ids)
        if max_length < 0:
            raise ValueError(f"max_length must be >= 0, got {max_length}")
        return list(ids[:max_length])

    def encode(self, text: str, max_length: int | None = 512) -> list[int]:
        """Encode *text* to token IDs.

        Args:
            text: Input string.
            max_length: Maximum number of ids returned (right-side truncation);
                ``None`` disables truncation.

        Returns:
            The token ids produced by the underlying tokenizer, truncated.

        Raises:
            ValueError: If *max_length* is negative.
        """
        return self._truncate(self._tokenizer.encode(text).ids, max_length)

    def apply_chat_template(self, prompt: str, max_length: int | None = 512) -> dict:
        """Wrap *prompt* in the Qwen3 chat format and tokenize it.

        The text encoded is::

            <|im_start|>user
            {prompt}<|im_end|>
            <|im_start|>assistant

        followed by a trailing newline. Truncation is applied to the whole
        formatted text, so a very long prompt can lose the trailing assistant
        header.

        Returns:
            ``{"input_ids": [...], "attention_mask": [...]}``. No padding is
            added, so the mask is all ones and matches ``input_ids`` in length.

        Raises:
            ValueError: If *max_length* is negative.
        """
        # Built manually rather than rendering self._chat_template.
        chat_text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        ids = self.encode(chat_text, max_length)
        return {"input_ids": ids, "attention_mask": [1] * len(ids)}

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the underlying tokenizer."""
        return self._tokenizer.get_vocab_size()