| """ |
| Hugging Face tokenizer wrapper for nanochat's rustbpe+tiktoken vocabulary. |
| """ |

from __future__ import annotations

import os
import pickle
from typing import Dict, List, Optional, Tuple

import tiktoken
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer

# Support both package-relative and flat (trust_remote_code) imports.
try:
    from .configuration_nanochat import NanoChatConfig
except ImportError:
    from configuration_nanochat import NanoChatConfig

SPECIAL_TOKENS = [
    "<|bos|>",
    "<|user_start|>",
    "<|user_end|>",
    "<|assistant_start|>",
    "<|assistant_end|>",
    "<|python_start|>",
    "<|python_end|>",
    "<|output_start|>",
    "<|output_end|>",
]
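
# For reference, apply_chat_template below renders a conversation as:
#
#   <|bos|><|user_start|>Hi<|user_end|><|assistant_start|>Hello!<|assistant_end|>...
#
# with optional <|python_start|>code<|python_end|> and
# <|output_start|>result<|output_end|> segments inside assistant turns, and a
# trailing <|assistant_start|> when add_generation_prompt=True.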


class NanoChatTokenizer(PreTrainedTokenizer):
    """Slow tokenizer that wraps the pickled tiktoken.Encoding trained by rustbpe."""

    vocab_files_names = {"tokenizer_file": "tokenizer/tokenizer.pkl"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, tokenizer_file: Optional[str] = None, **kwargs):
        if tokenizer_file is None:
            raise ValueError("tokenizer_file must be provided")

        # Keep an untouched copy of kwargs for the hub-download options below;
        # pop the special-token overrides so they can be forwarded explicitly.
        # nanochat uses <|bos|> as its single catch-all special token.
        init_kwargs = dict(kwargs)
        bos_token = kwargs.pop("bos_token", "<|bos|>")
        eos_token = kwargs.pop("eos_token", "<|bos|>")
        unk_token = kwargs.pop("unk_token", "<|bos|>")
        pad_token = kwargs.pop("pad_token", "<|bos|>")

        # Resolve the tokenizer pickle: prefer a local file, else pull it from
        # the hub repo this tokenizer is being loaded from.
        resolved_path = tokenizer_file
        if not os.path.isfile(resolved_path):
            repo_id = init_kwargs.get("name_or_path") or init_kwargs.get("pretrained_model_name_or_path")
            if repo_id:
                from huggingface_hub import hf_hub_download

                resolved_path = hf_hub_download(
                    repo_id,
                    tokenizer_file,
                    revision=init_kwargs.get("revision"),
                    subfolder=init_kwargs.get("subfolder"),
                    cache_dir=init_kwargs.get("cache_dir"),
                    token=init_kwargs.get("token"),
                )
        if not os.path.isfile(resolved_path):
            raise FileNotFoundError(f"Cannot locate tokenizer state at {tokenizer_file}")

        with open(resolved_path, "rb") as handle:
            self._encoding: tiktoken.Encoding = pickle.load(handle)

        # Build the id <-> token tables PreTrainedTokenizer expects. Token
        # bytes are decoded lossily (errors="replace"), so these string forms
        # are bookkeeping only; real decoding goes through tiktoken.
        self._id_to_token: List[str] = []
        for token_id, token_bytes in enumerate(self._encoding.token_byte_values()):
            token = token_bytes.decode("utf-8", errors="replace")
            self._id_to_token.append(token)
        self.vocab: Dict[str, int] = {token: idx for idx, token in enumerate(self._id_to_token)}

        # Special tokens sit above the BPE vocabulary; extend the table to cover them.
        self._special_token_ids: Dict[str, int] = {}
        for special in SPECIAL_TOKENS:
            special_id = self._encoding.encode_single_token(special)
            if special_id >= len(self._id_to_token):
                self._id_to_token.extend([""] * (special_id - len(self._id_to_token) + 1))
            self._id_to_token[special_id] = special
            self.vocab[special] = special_id
            self._special_token_ids[special] = special_id

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )

        # Pin the special tokens and their ids so they stay consistent with
        # the tiktoken vocabulary regardless of what the parent class inferred.
        self.bos_token = bos_token or "<|bos|>"
        self.eos_token = eos_token or "<|bos|>"
        self.unk_token = unk_token or "<|bos|>"
        self.pad_token = pad_token or "<|bos|>"

        self.bos_token_id = self.vocab[self.bos_token]
        self.eos_token_id = self.vocab[self.eos_token]
        self.unk_token_id = self.vocab[self.unk_token]
        self.pad_token_id = self.vocab[self.pad_token]

    @property
    def vocab_size(self) -> int:
        return len(self._id_to_token)

    def get_vocab(self) -> Dict[str, int]:
        return dict(self.vocab)

    def _tokenize(self, text: str) -> List[str]:
        token_ids = self._encoding.encode(text, allowed_special=set())
        return [self._id_to_token[token_id] for token_id in token_ids]

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        return self._id_to_token[index]

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        if token_ids_1 is not None:
            raise ValueError("nanochat tokenizer only supports single sequences")
        return [self.bos_token_id] + token_ids_0

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        del token_ids_1
        # The +1 accounts for the <|bos|> prepended by build_inputs_with_special_tokens.
        return [0] * (len(token_ids_0) + 1)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        target_dir = os.path.join(save_directory, "tokenizer")
        os.makedirs(target_dir, exist_ok=True)
        filename = (filename_prefix + "-" if filename_prefix else "") + "tokenizer.pkl"
        dest_file = os.path.join(target_dir, filename)
        with open(dest_file, "wb") as handle:
            pickle.dump(self._encoding, handle)
        return (dest_file,)
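
    # Round-trip sketch (the output directory name is hypothetical):
    #
    #   tok.save_pretrained("out/")  # writes out/tokenizer/tokenizer.pkl
    #   tok2 = NanoChatTokenizer(tokenizer_file="out/tokenizer/tokenizer.pkl")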

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        spaces_between_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        # Decoding goes straight through tiktoken so raw bytes survive round-trips.
        del clean_up_tokenization_spaces, spaces_between_special_tokens, kwargs
        if skip_special_tokens:
            token_ids = [tid for tid in token_ids if tid not in self.all_special_ids]
        return self._encoding.decode(token_ids)

    def apply_chat_template(
        self,
        conversation,
        add_generation_prompt: bool = False,
        tokenize: bool = False,
        return_tensors: Optional[str] = None,
        **kwargs,
    ):
        if not isinstance(conversation, list) or not conversation:
            raise ValueError("conversation must be a non-empty list of messages")

        # nanochat has no dedicated system role: a leading system prompt is
        # merged into the first user message.
        messages = conversation
        if messages[0]["role"] == "system":
            if len(messages) < 2 or messages[1]["role"] != "user":
                raise ValueError("system prompt must be followed by a user message")
            merged = messages[0]["content"] + "\n\n" + messages[1]["content"]
            messages = [dict(messages[1], content=merged)] + messages[2:]

        token_ids: List[int] = [self.bos_token_id]

        def encode_text(text: str) -> List[int]:
            return self._encoding.encode(text, allowed_special=set()) if text else []

        user_start = self._special_token_ids["<|user_start|>"]
        user_end = self._special_token_ids["<|user_end|>"]
        assistant_start = self._special_token_ids["<|assistant_start|>"]
        assistant_end = self._special_token_ids["<|assistant_end|>"]
        python_start = self._special_token_ids["<|python_start|>"]
        python_end = self._special_token_ids["<|python_end|>"]
        output_start = self._special_token_ids["<|output_start|>"]
        output_end = self._special_token_ids["<|output_end|>"]

        # After any system merge, turns must strictly alternate user/assistant.
        for idx, message in enumerate(messages):
            expected_role = "user" if idx % 2 == 0 else "assistant"
            if message["role"] != expected_role:
                raise ValueError(f"Message {idx} should be from {expected_role}, got {message['role']}")
            content = message["content"]
            if message["role"] == "user":
                if not isinstance(content, str):
                    raise ValueError("User messages must be plain strings")
                token_ids.append(user_start)
                token_ids.extend(encode_text(content))
                token_ids.append(user_end)
            else:
                token_ids.append(assistant_start)
                if isinstance(content, str):
                    token_ids.extend(encode_text(content))
                elif isinstance(content, list):
                    # Structured assistant turns: text parts plus optional
                    # python tool calls and their captured outputs.
                    for part in content:
                        part_type = part.get("type", "text")
                        value = part.get("text", "")
                        if part_type == "text":
                            token_ids.extend(encode_text(value))
                        elif part_type == "python":
                            token_ids.append(python_start)
                            token_ids.extend(encode_text(value))
                            token_ids.append(python_end)
                        elif part_type == "python_output":
                            token_ids.append(output_start)
                            token_ids.extend(encode_text(value))
                            token_ids.append(output_end)
                        else:
                            raise ValueError(f"Unsupported assistant part type: {part_type}")
                else:
                    raise ValueError(f"Assistant content must be str or list, got {type(content)}")
                token_ids.append(assistant_end)

        if add_generation_prompt:
            # Leave the transcript open at an assistant turn for generation.
            token_ids.append(assistant_start)

        if tokenize:
            # Follow the HF convention: a plain id list unless tensors are requested.
            if return_tensors is None:
                return token_ids
            if return_tensors != "pt":
                raise ValueError("Only return_tensors='pt' is supported")
            import torch

            return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        return self._encoding.decode(token_ids)


# Let AutoTokenizer resolve NanoChat checkpoints to this class.
AutoTokenizer.register(NanoChatConfig, NanoChatTokenizer)
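

if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming a local pickle at this
    # (hypothetical) path relative to the working directory.
    tok = NanoChatTokenizer(tokenizer_file="tokenizer/tokenizer.pkl")
    convo = [
        {"role": "system", "content": "Be terse."},
        {"role": "user", "content": "What is 2 + 2?"},
    ]
    print(tok.apply_chat_template(convo, add_generation_prompt=True))
    print(tok.apply_chat_template(convo, add_generation_prompt=True, tokenize=True))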