rwkv-vl-test / tokenizer_core.py
Jo1uck's picture
Upload folder using huggingface_hub
cb105bf verified
Raw
History Blame Contribute Delete
15.7 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import ast
import json
import re
from dataclasses import dataclass
from typing import Any
# Default for the ``vocab_size`` argument of RWKVTokenizerCore.
DEFAULT_VOCAB_SIZE = 65536
# Default special-token strings. BOS and UNK share the control character
# "\x16"; EOS and PAD share the control character "\x17".
DEFAULT_BOS_TOKEN = "\x16"
DEFAULT_EOS_TOKEN = "\x17"
DEFAULT_PAD_TOKEN = "\x17"
DEFAULT_UNK_TOKEN = "\x16"
# Vision markers. These three strings are the keys of
# SPECIAL_TOKEN_TEXT_TO_ID, which assigns them fixed ids.
DEFAULT_IMAGE_TOKEN = "<|image_pad|>"
DEFAULT_VISION_START_TOKEN = "<|vision_start|>"
DEFAULT_VISION_END_TOKEN = "<|vision_end|>"
# Placeholder that the chat templates below emit for image content items;
# RWKVTokenizerCore.expand_image_placeholders replaces it with a vision block.
DEFAULT_IMAGE_PLACEHOLDER_TOKEN = "<image>"
# Jinja chat template (plain variant).
#
# For every message it emits:
#   * the escape sequence \x16 followed by a role label and a colon, where the
#     label is "Assistant" for role 'assistant', "System" for role 'system'
#     and "User" for any other role;
#   * the content: a string is emitted as-is; otherwise each item of the
#     content list is emitted as the literal '<image>' (item type 'image' or
#     'image_url') or as item['text'] (item type 'text'); items of any other
#     type emit nothing;
#   * the escape sequence \x17.
# When ``add_generation_prompt`` is truthy it appends "\x16Assistant:" and,
# if ``thinking`` is defined and truthy, " <think>".
#
# NOTE(review): the Python literals use '\\x16' / '\\x17', so the Jinja source
# contains the backslash escape text rather than the control character itself;
# this presumably relies on Jinja resolving escapes inside its string
# literals - confirm.
# NOTE(review): '<image>' is hard-coded here rather than taken from the
# ``image_placeholder_token`` template variable, so it only matches
# DEFAULT_IMAGE_PLACEHOLDER_TOKEN.
CHAT_TEMPLATE = (
"{% for message in messages %}"
"{{ '\\x16' + ('Assistant' if message['role'] == 'assistant' else 'System' if message['role'] == 'system' else 'User') + ':' }}"
"{% if message['content'] is string %}"
"{{ message['content'] }}"
"{% else %}"
"{% for item in message['content'] %}"
"{% if item['type'] == 'image' or item['type'] == 'image_url' %}{{ '<image>' }}{% elif item['type'] == 'text' %}{{ item['text'] }}{% endif %}"
"{% endfor %}"
"{% endif %}"
"{{ '\\x17' }}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '\\x16Assistant:' }}"
"{% if thinking is defined and thinking %}{{ ' <think>' }}{% endif %}"
"{% endif %}"
)
# Jinja chat template ("fake thinking" variant).
#
# Identical to CHAT_TEMPLATE except for one extra line: directly after the
# role prefix of every message whose role is 'assistant', it emits
# " <think>", a newline escape, "</think>" and another newline escape, i.e.
# an empty thinking section placed before the assistant content.
#
# The generation prompt is unchanged: "\x16Assistant:" plus " <think>" when
# ``thinking`` is defined and truthy.
#
# NOTE(review): as with CHAT_TEMPLATE, the '\\x16', '\\x17' and '\\n' Python
# literals reach Jinja as backslash escape text; presumably Jinja resolves
# them inside its string literals - confirm.
CHAT_TEMPLATE_FAKE_THINKING = (
"{% for message in messages %}"
"{{ '\\x16' + ('Assistant' if message['role'] == 'assistant' else 'System' if message['role'] == 'system' else 'User') + ':' }}"
"{% if message['role'] == 'assistant' %}{{ ' <think>\\n</think>\\n' }}{% endif %}"
"{% if message['content'] is string %}"
"{{ message['content'] }}"
"{% else %}"
"{% for item in message['content'] %}"
"{% if item['type'] == 'image' or item['type'] == 'image_url' %}{{ '<image>' }}{% elif item['type'] == 'text' %}{{ item['text'] }}{% endif %}"
"{% endfor %}"
"{% endif %}"
"{{ '\\x17' }}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '\\x16Assistant:' }}"
"{% if thinking is defined and thinking %}{{ ' <think>' }}{% endif %}"
"{% endif %}"
)
# Fixed ids for the vision marker tokens. RWKVTokenizerCore.token_to_id
# consults this table before the vocabulary loaded from the vocab file.
# The ids are all below DEFAULT_VOCAB_SIZE (65536).
SPECIAL_TOKEN_TEXT_TO_ID = {
DEFAULT_VISION_START_TOKEN: 65530,
DEFAULT_VISION_END_TOKEN: 65531,
DEFAULT_IMAGE_TOKEN: 65532,
}
# Inverse of SPECIAL_TOKEN_TEXT_TO_ID: id -> token text.
SPECIAL_TOKEN_ID_TO_TEXT = {
token_id: text for text, token_id in SPECIAL_TOKEN_TEXT_TO_ID.items()
}
@dataclass(frozen=True)
class RWKVSpecialTokens:
    """Immutable bundle of the special-token strings used by the tokenizer.

    Every field is a plain string and defaults to the matching module-level
    ``DEFAULT_*`` constant. Instances are frozen, so fields cannot be
    reassigned after construction.
    """

    # Sequence-level markers.
    bos_token: str = DEFAULT_BOS_TOKEN
    eos_token: str = DEFAULT_EOS_TOKEN
    pad_token: str = DEFAULT_PAD_TOKEN
    unk_token: str = DEFAULT_UNK_TOKEN

    # Vision markers: one repeated image token wrapped by start/end tokens.
    image_token: str = DEFAULT_IMAGE_TOKEN
    vision_start_token: str = DEFAULT_VISION_START_TOKEN
    vision_end_token: str = DEFAULT_VISION_END_TOKEN

    # Textual stand-in for an image in rendered chat text.
    image_placeholder_token: str = DEFAULT_IMAGE_PLACEHOLDER_TOKEN
class ByteTrie:
    """Trie keyed by single byte values, mapping byte strings to token ids.

    Each node stores its outgoing edges in ``children`` (byte value -> node)
    and, when a token ends at that node, the token id in ``value``.
    """

    __slots__ = ("children", "value")

    def __init__(self) -> None:
        self.children: dict[int, "ByteTrie"] = {}
        self.value: int | None = None

    def add(self, token: bytes, token_id: int) -> None:
        """Insert *token* so that it resolves to *token_id*.

        Re-adding an existing token overwrites its id.
        """
        cursor = self
        for byte in token:
            following = cursor.children.get(byte)
            if following is None:
                following = cursor.children[byte] = ByteTrie()
            cursor = following
        cursor.value = token_id

    def longest(self, data: bytes, start: int) -> tuple[int, int]:
        """Find the longest stored token that begins at ``data[start]``.

        Returns:
            ``(end, token_id)`` where ``end`` is the offset just past the
            matched token.

        Raises:
            ValueError: if no stored token matches at ``start``.
        """
        cursor = self
        found: tuple[int, int] | None = None
        for offset in range(start, len(data)):
            cursor = cursor.children.get(data[offset])
            if cursor is None:
                # No edge for this byte: the walk cannot go any deeper.
                break
            if cursor.value is not None:
                # A token ends here; remember it but keep looking for a
                # longer one.
                found = (offset + 1, cursor.value)
        if found is None:
            raise ValueError(f"RWKV tokenizer could not encode byte at offset {start}")
        return found
def _as_bytes(token: str | bytes) -> bytes:
return token.encode("utf-8") if isinstance(token, str) else token
def load_rwkv_vocab(vocab_file: str) -> tuple[dict[int, bytes], dict[bytes, int]]:
    """Parse an RWKV vocabulary file into id->token and token->id maps.

    Each non-blank line has the form ``<id> <python-literal> <byte-length>``.
    The literal is a ``str`` or ``bytes`` Python literal; ``str`` tokens are
    UTF-8 encoded. The trailing number must equal the token's byte length.
    Blank and whitespace-only lines are skipped, so a trailing empty line
    does not abort loading.

    Args:
        vocab_file: Path to the vocabulary file (read as UTF-8 text).

    Returns:
        ``(idx2token, token2idx)``. If the same token bytes appear under
        several ids, ``token2idx`` keeps the id that occurs last in the file.

    Raises:
        ValueError: if a line has no space separator, the id or length is
            not an integer, the literal is neither ``str`` nor ``bytes``, or
            the declared length disagrees with the token's byte length.
    """
    idx2token: dict[int, bytes] = {}
    with open(vocab_file, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.rstrip("\n")
            if not line.strip():
                # Tolerate blank / whitespace-only lines.
                continue
            # The token literal may itself contain spaces, so only the first
            # and the last space on the line are treated as separators.
            first_space = line.find(" ")
            last_space = line.rfind(" ")
            if first_space == -1:
                raise ValueError(
                    f"Invalid RWKV vocab line {line_no} "
                    f"(expected '<id> <token> <length>'): {line!r}"
                )
            token_id = int(line[:first_space])
            # literal_eval only evaluates Python literals, never arbitrary code.
            token = ast.literal_eval(line[first_space + 1 : last_space])
            if isinstance(token, str):
                token = token.encode("utf-8")
            if not isinstance(token, bytes):
                raise ValueError(f"Invalid RWKV vocab token on line: {line}")
            token_len = int(line[last_space + 1 :])
            if len(token) != token_len:
                raise ValueError(
                    f"Invalid RWKV vocab token length for id {token_id}: "
                    f"expected {token_len}, got {len(token)}"
                )
            idx2token[token_id] = token
    token2idx = {token: idx for idx, token in idx2token.items()}
    return idx2token, token2idx
def build_chat_template(template: str):
    """Compile *template* with Jinja2 inside an immutable sandboxed environment.

    The environment enables ``trim_blocks``, ``lstrip_blocks`` and the
    ``loopcontrols`` extension, and exposes three helpers to templates:

    * ``raise_exception(msg)`` - global that raises a Jinja ``TemplateError``.
    * ``strftime_now(fmt)`` - global returning the current local time
      formatted with ``fmt``.
    * ``tojson`` - filter that serialises via ``json.dumps`` with
      ``ensure_ascii=False`` by default.

    Jinja2 is imported lazily so the module can be imported without it.

    Returns:
        The compiled Jinja2 template object.
    """
    import jinja2
    import jinja2.ext
    import jinja2.sandbox

    def raise_exception(msg):
        # Gives template authors a way to abort rendering with a message.
        raise jinja2.exceptions.TemplateError(msg)

    def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
        dump_options = {
            "ensure_ascii": ensure_ascii,
            "indent": indent,
            "separators": separators,
            "sort_keys": sort_keys,
        }
        return json.dumps(x, **dump_options)

    def strftime_now(fmt):
        from datetime import datetime

        current = datetime.now()
        return current.strftime(fmt)

    sandbox = jinja2.sandbox.ImmutableSandboxedEnvironment(
        extensions=[jinja2.ext.loopcontrols],
        lstrip_blocks=True,
        trim_blocks=True,
    )
    sandbox.filters["tojson"] = tojson
    sandbox.globals.update(
        raise_exception=raise_exception,
        strftime_now=strftime_now,
    )
    return sandbox.from_string(template)
class RWKVTokenizerCore:
    """Byte-level greedy longest-match tokenizer with special tokens and chat rendering.

    The vocabulary is loaded from an RWKV vocab file into a ``ByteTrie``.
    Text is first split on special-token strings; special tokens map straight
    to their ids, everything else is UTF-8 encoded and tokenized by repeatedly
    taking the longest vocabulary entry that matches at the current offset.

    The class also renders Jinja chat templates and expands image
    placeholders into ``vision_start + image_token * n + vision_end`` blocks.
    """

    def __init__(
        self,
        vocab_file: str,
        *,
        vocab_size: int = DEFAULT_VOCAB_SIZE,
        special_tokens: RWKVSpecialTokens | None = None,
        add_bos_token: bool = False,
        add_eos_token: bool = False,
        chat_template: str | None = None,
    ) -> None:
        """Load the vocabulary and prepare special-token handling.

        Args:
            vocab_file: Path passed to ``load_rwkv_vocab``.
            vocab_size: Stored on the instance as ``vocab_size``; it is not
                otherwise used by this class.
            special_tokens: Special-token strings; defaults to
                ``RWKVSpecialTokens()``.
            add_bos_token: Default for ``encode(add_bos=None)``.
            add_eos_token: Default for ``encode(add_eos=None)``.
            chat_template: Optional Jinja template source to compile
                immediately via ``set_chat_template``.
        """
        self.vocab_file = vocab_file
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens or RWKVSpecialTokens()
        self.default_add_bos = add_bos_token
        self.default_add_eos = add_eos_token

        self.idx2token, self.token2idx = load_rwkv_vocab(vocab_file)
        self.root = ByteTrie()
        for token, token_id in self.token2idx.items():
            self.root.add(token, token_id)

        # Per-instance copies so the module-level tables are never mutated.
        self.special_token_text_to_id = dict(SPECIAL_TOKEN_TEXT_TO_ID)
        self.special_token_id_to_text = dict(SPECIAL_TOKEN_ID_TO_TEXT)

        # Each id is None when the token string is neither in the
        # special-token table nor in the vocabulary.
        self.bos_id = self.token_to_id(self.special_tokens.bos_token)
        self.eos_id = self.token_to_id(self.special_tokens.eos_token)
        self.pad_id = self.token_to_id(self.special_tokens.pad_token)
        self.unk_id = self.token_to_id(self.special_tokens.unk_token)
        self.image_id = self.token_to_id(self.special_tokens.image_token)
        self.vision_start_id = self.token_to_id(self.special_tokens.vision_start_token)
        self.vision_end_id = self.token_to_id(self.special_tokens.vision_end_token)

        pattern_tokens = {
            *self.special_token_text_to_id,
            self.special_tokens.bos_token,
            self.special_tokens.eos_token,
            self.special_tokens.pad_token,
            self.special_tokens.unk_token,
            self.special_tokens.image_token,
            self.special_tokens.vision_start_token,
            self.special_tokens.vision_end_token,
        }
        # Longest alternatives first, so a special token that is a prefix of
        # another never shadows the longer one. The capturing group makes
        # re.split keep the special tokens as their own chunks.
        pattern = "|".join(
            re.escape(token)
            for token in sorted(pattern_tokens, key=len, reverse=True)
            if token
        )
        self.special_token_pattern = re.compile(f"({pattern})") if pattern else None

        self._chat_template = None
        if chat_template is not None:
            self.set_chat_template(chat_template)

    @property
    def image_token(self) -> str:
        """The repeated per-image token string."""
        return self.special_tokens.image_token

    @property
    def vision_start_token(self) -> str:
        """The string that opens a vision block."""
        return self.special_tokens.vision_start_token

    @property
    def vision_end_token(self) -> str:
        """The string that closes a vision block."""
        return self.special_tokens.vision_end_token

    @property
    def image_placeholder_token(self) -> str:
        """The textual placeholder that stands for one image."""
        return self.special_tokens.image_placeholder_token

    @property
    def vision_image_token(self) -> str:
        """A vision block containing exactly one image token."""
        return f"{self.vision_start_token}{self.image_token}{self.vision_end_token}"

    def template_kwargs(self) -> dict[str, Any]:
        """Return the special-token variables passed to every template render."""
        return {
            "bos_token": self.special_tokens.bos_token,
            "eos_token": self.special_tokens.eos_token,
            "pad_token": self.special_tokens.pad_token,
            "unk_token": self.special_tokens.unk_token,
            "image_token": self.image_token,
            "vision_start_token": self.vision_start_token,
            "vision_end_token": self.vision_end_token,
            "image_placeholder_token": self.image_placeholder_token,
        }

    def set_chat_template(self, template: str) -> None:
        """Compile *template* and make it the active chat template."""
        self._chat_template = build_chat_template(template)

    def render_chat_template(
        self,
        messages: list[dict[str, Any]],
        **kwargs,
    ) -> str:
        """Render *messages* with the active chat template.

        Extra keyword arguments are forwarded to the template and override
        the defaults from ``template_kwargs`` on key collisions.

        Raises:
            ValueError: if no chat template has been set.
        """
        if self._chat_template is None:
            raise ValueError("No chat template set. Call set_chat_template() first.")
        template_kwargs = self.template_kwargs()
        template_kwargs.update(kwargs)
        return self._chat_template.render(messages=messages, **template_kwargs)

    def _encode_bytes(self, data: bytes) -> list[int]:
        """Tokenize *data* by greedy longest match against the vocabulary trie.

        Raises:
            ValueError: propagated from ``ByteTrie.longest`` when no
                vocabulary entry matches at some offset.
        """
        idx = 0
        tokens = []
        while idx < len(data):
            idx, token_id = self.root.longest(data, idx)
            tokens.append(token_id)
        return tokens

    def encode(
        self,
        text: str,
        *,
        add_bos: bool | None = None,
        add_eos: bool | None = None,
    ) -> list[int]:
        """Convert *text* into token ids.

        Args:
            text: Input string; special-token strings inside it are mapped
                to their ids directly.
            add_bos: Prepend ``bos_id``; ``None`` uses the constructor default.
            add_eos: Append ``eos_id``; ``None`` uses the constructor default.

        Returns:
            The token ids. BOS/EOS are only added when the respective id
            is not ``None``.
        """
        if add_bos is None:
            add_bos = self.default_add_bos
        if add_eos is None:
            add_eos = self.default_add_eos

        tokens: list[int] = []
        chunks = (
            self.special_token_pattern.split(text)
            if self.special_token_pattern is not None
            else [text]
        )
        for chunk in chunks:
            if not chunk:
                continue
            token_id = self.token_to_id(chunk)
            if token_id is not None and token_id != self.unk_id:
                # The whole chunk is a single known token (special or vocab).
                tokens.append(token_id)
            else:
                tokens.extend(self._encode_bytes(chunk.encode("utf-8")))

        if add_bos and self.bos_id is not None:
            tokens.insert(0, self.bos_id)
        if add_eos and self.eos_id is not None:
            tokens.append(self.eos_id)
        return tokens

    def _token_bytes(self, token_id: int) -> bytes:
        """Return the byte string for *token_id*.

        The vocabulary is consulted first; ids missing there fall back to
        the special-token table, so ids produced by ``encode`` for the vision
        tokens can be decoded even when the vocab file does not list them.

        Raises:
            KeyError: if the id is in neither table.
        """
        token = self.idx2token.get(token_id)
        if token is not None:
            return token
        text = self.special_token_id_to_text.get(token_id)
        if text is None:
            raise KeyError(token_id)
        return text.encode("utf-8")

    def decode(self, token_ids: list[int] | tuple[int, ...]) -> str:
        """Convert token ids back to text.

        The tokens' bytes are concatenated before decoding, so multi-byte
        characters split across tokens are reassembled; invalid UTF-8 is
        replaced rather than raising.

        Raises:
            KeyError: if an id is neither in the vocabulary nor a special
                token id.
        """
        return b"".join(self._token_bytes(int(i)) for i in token_ids).decode(
            "utf-8",
            errors="replace",
        )

    def token_to_id(self, token: str | bytes | int) -> int | None:
        """Resolve *token* to an id.

        Integers are returned unchanged. Strings are looked up in the
        special-token table first, then (as UTF-8 bytes) in the vocabulary.
        Returns ``None`` when the token is unknown.
        """
        if isinstance(token, int):
            return token
        if isinstance(token, str) and token in self.special_token_text_to_id:
            return self.special_token_text_to_id[token]
        return self.token2idx.get(_as_bytes(token))

    def id_to_token(self, token_id: int) -> str | None:
        """Return the text of *token_id*, or ``None`` for an unknown id.

        Vocabulary entries are decoded as UTF-8 with replacement of invalid
        sequences. Ids that are not in the vocabulary fall back to the
        special-token table.
        """
        token_id = int(token_id)
        token = self.idx2token.get(token_id)
        if token is not None:
            return token.decode("utf-8", errors="replace")
        return self.special_token_id_to_text.get(token_id)

    def get_vocab(self) -> dict[str, int]:
        """Return a text->id mapping built from the vocabulary file entries.

        Tokens are decoded with ``errors="replace"``, so different byte
        tokens that are not valid UTF-8 can decode to the same key; in that
        case only one of their ids is kept.
        """
        return {
            token.decode("utf-8", errors="replace"): idx
            for idx, token in self.idx2token.items()
        }

    def save_vocabulary(self, vocab_file: str) -> None:
        """Write the vocabulary to *vocab_file*, one token per line by ascending id.

        Each line is ``<id> <repr of the token bytes> <byte length>``.
        """
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token_index, token in sorted(self.idx2token.items()):
                writer.write(f"{token_index} {repr(token)} {len(token)}\n")

    def _vision_block(self, n_tokens: int) -> str:
        """Return a vision block holding ``n_tokens`` image tokens."""
        return (
            f"{self.vision_start_token}"
            f"{self.image_token * n_tokens}"
            f"{self.vision_end_token}"
        )

    def expand_image_placeholders(
        self,
        rendered_text: str,
        image_token_counts: list[int],
    ) -> str:
        """Replace image placeholders in *rendered_text* with vision blocks.

        Two placeholder forms are recognised: the single-image vision block
        (``vision_image_token``) and ``image_placeholder_token``. The i-th
        placeholder found, scanning left to right, becomes a vision block
        with ``image_token_counts[i]`` image tokens.

        Placeholders beyond the number of counts are removed from the text.
        Counts beyond the number of placeholders are appended as extra
        vision blocks at the end of the text.
        """
        patterns = [self.vision_image_token, self.image_placeholder_token]
        deduped_patterns: list[str] = []
        for pattern in patterns:
            if pattern and pattern not in deduped_patterns:
                deduped_patterns.append(pattern)

        pieces: list[str] = []
        pos = 0
        image_idx = 0
        while pos < len(rendered_text):
            matches = []
            for pattern in deduped_patterns:
                start = rendered_text.find(pattern, pos)
                if start != -1:
                    # Sort key: earliest start wins; on a tie the longer
                    # pattern wins (negative length sorts first).
                    matches.append((start, -len(pattern), pattern))
            if not matches:
                pieces.append(rendered_text[pos:])
                break
            start, _, pattern = min(matches)
            pieces.append(rendered_text[pos:start])
            if image_idx < len(image_token_counts):
                pieces.append(self._vision_block(image_token_counts[image_idx]))
                image_idx += 1
            # else: no count left for this placeholder, so it is dropped.
            pos = start + len(pattern)

        # Images that had no placeholder in the text are appended at the end.
        pieces.extend(
            self._vision_block(n_tokens)
            for n_tokens in image_token_counts[image_idx:]
        )
        return "".join(pieces)

    def render_mm_chat(
        self,
        messages: list[dict[str, Any]],
        image_token_counts_by_message: list[list[int]],
        *,
        add_generation_prompt: bool = False,
    ) -> str:
        """Render a multimodal chat and expand its image placeholders.

        Args:
            messages: Chat messages for the active chat template.
            image_token_counts_by_message: One list per message giving the
                image-token count of each image in that message. The lists
                are flattened in message order before expansion.
            add_generation_prompt: Forwarded to the chat template.

        Raises:
            ValueError: if the two lists differ in length, or if no chat
                template is set.
        """
        if len(messages) != len(image_token_counts_by_message):
            raise ValueError(
                "image_token_counts_by_message must have one entry per message: "
                f"got {len(image_token_counts_by_message)} counts for "
                f"{len(messages)} messages"
            )
        rendered = self.render_chat_template(
            messages,
            add_generation_prompt=add_generation_prompt,
        ).rstrip("\n")
        image_token_counts = [
            count
            for message_counts in image_token_counts_by_message
            for count in message_counts
        ]
        return self.expand_image_placeholders(rendered, image_token_counts)

    def assistant_token_spans(
        self,
        messages: list[dict[str, Any]],
        image_token_counts_by_message: list[list[int]],
        *,
        add_bos: bool = True,
    ) -> list[tuple[int, int]]:
        """Locate the token ranges that belong to assistant messages.

        For each message with role ``"assistant"`` two renderings are
        encoded: the conversation before it with a generation prompt, and
        the conversation up to and including it without one. Their token
        counts form a half-open ``(start, end)`` span; empty spans are
        skipped.

        NOTE(review): this assumes that the tokens of the first rendering
        are a prefix of the tokens of the second. Greedy longest-match
        tokenization could merge bytes across that boundary - confirm for
        the templates in use.
        """
        spans = []
        for idx, message in enumerate(messages):
            if message.get("role") != "assistant":
                continue
            start_text = self.render_mm_chat(
                messages[:idx],
                image_token_counts_by_message[:idx],
                add_generation_prompt=True,
            )
            end_text = self.render_mm_chat(
                messages[: idx + 1],
                image_token_counts_by_message[: idx + 1],
                add_generation_prompt=False,
            )
            start = len(self.encode(start_text, add_bos=add_bos, add_eos=False))
            end = len(self.encode(end_text, add_bos=add_bos, add_eos=False))
            if start < end:
                spans.append((start, end))
        return spans