| """Tokenizers for FLUX pipeline — T5 (SentencePiece) and CLIP (BPE). |
| |
| Both tokenizers produce mx.array token ID tensors ready for encoder input. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| from pathlib import Path |
|
|
| import mlx.core as mx |
|
|
logger = logging.getLogger("image-server")  # shared app-wide logger channel (named, not __name__)
|
|
|
|
class T5Tokenizer:
    """T5-XXL SentencePiece tokenizer.

    Loads ``spiece.model`` from the FLUX.1-schnell repo
    (``tokenizer_2/spiece.model``).
    """

    def __init__(self, spiece_path: str, max_length: int = 256):
        # Local import: sentencepiece is only needed when a T5 tokenizer is built.
        import sentencepiece as spm

        self._sp = spm.SentencePieceProcessor()
        self._sp.Load(spiece_path)
        self.max_length = max_length
        self.pad_id = 0
        # Fix: raw SentencePiece ``Encode`` does not append the EOS (``</s>``)
        # token, but the HF T5 tokenizer the encoder was trained against
        # terminates every sequence with EOS before padding. Ask the model for
        # its EOS id; fall back to 1 (T5's conventional ``</s>`` id) if the
        # model does not declare one.
        eos = self._sp.eos_id()
        self.eos_id = eos if eos >= 0 else 1

    def tokenize(self, text: str) -> mx.array:
        """Tokenize text → [1, max_length] int32 tensor.

        The output is ``ids + [eos_id]`` right-padded with ``pad_id``;
        truncation happens first so the trailing EOS always fits.
        """
        ids = self._sp.Encode(text)

        # Reserve the final slot for EOS so truncated prompts still end
        # with ``</s>``, matching HF T5Tokenizer output.
        if len(ids) > self.max_length - 1:
            ids = ids[: self.max_length - 1]
        ids = ids + [self.eos_id]

        pad_len = self.max_length - len(ids)
        if pad_len > 0:
            ids = ids + [self.pad_id] * pad_len
        return mx.array(ids, dtype=mx.int32).reshape(1, -1)
|
|
|
|
class CLIPTokenizer:
    """CLIP BPE tokenizer.

    Loads ``vocab.json`` (token→id) and ``merges.txt`` (BPE merge rules)
    from ``tokenizer/`` in the FLUX.1-schnell repo.
    """

    BOS_ID = 49406
    EOS_ID = 49407

    def __init__(self, vocab_path: str, merges_path: str, max_length: int = 77):
        # token -> id mapping.
        with open(vocab_path, encoding="utf-8") as f:
            self._vocab: dict[str, int] = json.load(f)

        # Merge rules; rank is the raw line number in merges.txt, so lower
        # rank means the merge is applied earlier.
        self._merges: list[tuple[str, str]] = []
        self._merge_rank: dict[tuple[str, str], int] = {}
        with open(merges_path, encoding="utf-8") as f:
            for rank, raw in enumerate(f):
                entry = raw.strip()
                if not entry or entry.startswith("#"):
                    # Skip blanks and the "#version" header line.
                    continue
                fields = entry.split()
                if len(fields) != 2:
                    continue
                merge = (fields[0], fields[1])
                self._merges.append(merge)
                self._merge_rank[merge] = rank

        self.max_length = max_length
        self.pad_id = 0

        # Word-splitting pattern: contractions, letter runs, single digits,
        # and punctuation runs.
        import regex
        self._pat = regex.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d|"""
            r"""[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )

    def _bpe(self, token: str) -> list[str]:
        """Apply BPE merges to a single word token."""
        if not token:
            return []
        if len(token) == 1:
            return [token + "</w>"]

        # Seed with single characters; "</w>" marks the end of the word.
        symbols = list(token[:-1]) + [token[-1] + "</w>"]
        rank_of = self._merge_rank.get
        inf = float("inf")

        while len(symbols) > 1:
            # Lowest-rank adjacent pair (first occurrence wins on ties).
            candidate = min(
                zip(symbols, symbols[1:]),
                key=lambda p: rank_of(p, inf),
            )
            if rank_of(candidate, inf) == inf:
                break  # no known merge applies anymore

            # Collapse every left-to-right, non-overlapping occurrence.
            merged: list[str] = []
            j = 0
            total = len(symbols)
            while j < total:
                if j + 1 < total and (symbols[j], symbols[j + 1]) == candidate:
                    merged.append(symbols[j] + symbols[j + 1])
                    j += 2
                else:
                    merged.append(symbols[j])
                    j += 1
            symbols = merged

        return symbols

    def tokenize(self, text: str) -> mx.array:
        """Tokenize text → [1, max_length] int32 tensor."""
        lookup = self._vocab.get
        ids = [self.BOS_ID]

        # BPE each regex-split word; unknown subwords fall back to id 0.
        for piece in self._pat.finditer(text.lower().strip()):
            ids.extend(lookup(sub, 0) for sub in self._bpe(piece.group()))

        ids.append(self.EOS_ID)

        # Truncate (always keeping a trailing EOS), then right-pad.
        if len(ids) > self.max_length:
            ids = ids[: self.max_length - 1] + [self.EOS_ID]
        ids += [self.pad_id] * max(0, self.max_length - len(ids))

        return mx.array(ids, dtype=mx.int32).reshape(1, -1)
|
|