# FLUX.1-schnell-MLX / tokenizers.py
# Uploaded via huggingface_hub by illusion615 (commit 31f3da5, verified)
"""Tokenizers for FLUX pipeline — T5 (SentencePiece) and CLIP (BPE).
Both tokenizers produce mx.array token ID tensors ready for encoder input.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
import mlx.core as mx
logger = logging.getLogger("image-server")
class T5Tokenizer:
    """T5-XXL SentencePiece tokenizer.

    Loads ``spiece.model`` from the FLUX.1-schnell repo
    (``tokenizer_2/spiece.model``) and encodes prompts into fixed-length
    ``mx.array`` id tensors ready for the T5 encoder.
    """

    def __init__(self, spiece_path: str, max_length: int = 256):
        # Imported lazily so the module loads even when sentencepiece
        # is not installed and this tokenizer is unused.
        import sentencepiece as spm

        processor = spm.SentencePieceProcessor()
        processor.Load(spiece_path)
        self._sp = processor
        self.max_length = max_length
        # id 0 — <pad> in the standard T5/SentencePiece vocab
        self.pad_id = 0

    def tokenize(self, text: str) -> mx.array:
        """Encode *text* into a ``[1, max_length]`` int32 token tensor.

        Token ids are truncated to ``max_length`` and right-padded with
        ``pad_id`` so the output shape is always the same.
        """
        ids = list(self._sp.Encode(text))[: self.max_length]
        ids += [self.pad_id] * (self.max_length - len(ids))
        return mx.array(ids, dtype=mx.int32).reshape(1, -1)
class CLIPTokenizer:
    """CLIP BPE tokenizer.

    Loads ``vocab.json`` (token → id) and ``merges.txt`` (ranked BPE merge
    rules) from ``tokenizer/`` in the FLUX.1-schnell repo, and encodes
    prompts into fixed-length ``mx.array`` id tensors framed by the CLIP
    start/end special tokens.
    """

    BOS_ID = 49406  # <|startoftext|>
    EOS_ID = 49407  # <|endoftext|>

    def __init__(self, vocab_path: str, merges_path: str, max_length: int = 77):
        # token string -> integer id
        with open(vocab_path, encoding="utf-8") as fh:
            self._vocab: dict[str, int] = json.load(fh)
        # Ranked merge rules; lower rank = applied earlier.  The rank is the
        # raw line number, so skipped header/blank lines leave gaps — only
        # the relative ordering matters to _bpe.
        self._merges: list[tuple[str, str]] = []
        self._merge_rank: dict[tuple[str, str], int] = {}
        with open(merges_path, encoding="utf-8") as fh:
            for rank, raw in enumerate(fh):
                raw = raw.strip()
                if not raw or raw.startswith("#"):
                    continue  # blank line or "#version" header
                fields = raw.split()
                if len(fields) != 2:
                    continue
                pair = (fields[0], fields[1])
                self._merges.append(pair)
                self._merge_rank[pair] = rank
        self.max_length = max_length
        # Pads with id 0 — matches OpenAI's reference clip.tokenize,
        # which zero-pads (HF's CLIPTokenizer pads with EOS instead).
        self.pad_id = 0
        # Imported lazily; `regex` (not stdlib `re`) is needed for \p{...}.
        import regex
        self._pat = regex.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d|"""
            r"""[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )

    def _bpe(self, token: str) -> list[str]:
        """Greedily merge *token* into BPE pieces; the last piece carries
        the ``</w>`` end-of-word marker."""
        if not token:
            return []
        if len(token) == 1:
            return [token + "</w>"]
        pieces = list(token[:-1]) + [token[-1] + "</w>"]
        no_rank = float("inf")
        while len(pieces) > 1:
            # Lowest-ranked adjacent pair; min() keeps the leftmost
            # occurrence on ties, matching a positional scan.
            candidate = min(
                zip(pieces, pieces[1:]),
                key=lambda p: self._merge_rank.get(p, no_rank),
            )
            if self._merge_rank.get(candidate, no_rank) == no_rank:
                break  # no mergeable pair left
            # Fuse every non-overlapping occurrence, left to right.
            merged: list[str] = []
            j = 0
            while j < len(pieces):
                if j + 1 < len(pieces) and (pieces[j], pieces[j + 1]) == candidate:
                    merged.append(pieces[j] + pieces[j + 1])
                    j += 2
                else:
                    merged.append(pieces[j])
                    j += 1
            pieces = merged
        return pieces

    def tokenize(self, text: str) -> mx.array:
        """Encode *text* into a ``[1, max_length]`` int32 token tensor.

        Output layout: BOS, BPE token ids, EOS, then ``pad_id`` padding.
        Over-long inputs are truncated so EOS is always the final real token.
        """
        text = text.lower().strip()
        ids = [self.BOS_ID]
        for match in self._pat.finditer(text):
            # Unknown pieces fall back to id 0.
            ids.extend(self._vocab.get(piece, 0) for piece in self._bpe(match.group()))
        ids.append(self.EOS_ID)
        if len(ids) > self.max_length:
            # Keep BOS at the front and force EOS into the last slot.
            ids = ids[: self.max_length - 1] + [self.EOS_ID]
        ids += [self.pad_id] * (self.max_length - len(ids))
        return mx.array(ids, dtype=mx.int32).reshape(1, -1)