| import os |
| import pickle |
| import shutil |
| from typing import Dict, List, Optional, Sequence, Tuple
|
|
| from huggingface_hub import hf_hub_download, snapshot_download |
| from huggingface_hub.utils import HfHubHTTPError |
| from transformers import PreTrainedTokenizer |
|
|
|
|
| class _BaseNanoGPTTokenizer: |
| """Lightweight wrapper used by the base (non-chat) checkpoints.""" |
|
|
| special_tokens = { |
| "bos": "<|bos|>", |
| "user_start": "<|user_start|>", |
| "user_end": "<|user_end|>", |
| "assistant_start": "<|assistant_start|>", |
| "assistant_end": "<|assistant_end|>", |
| "python_start": "<|python_start|>", |
| "python_end": "<|python_end|>", |
| "output_start": "<|output_start|>", |
| "output_end": "<|output_end|>", |
| } |
|
|
| def __init__(self, enc): |
| self.enc = enc |
| self.bos_token_id = enc.encode_single_token(self.special_tokens["bos"]) |
|
|
| @classmethod |
| def register_for_auto_class(cls, auto_class="AutoTokenizer"): |
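|         # Intentionally a no-op: this lightweight wrapper is not a transformers
|         # tokenizer, but callers may invoke this for API parity with
|         # PreTrainedTokenizer subclasses.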
| pass |
|
|
| @classmethod |
| def _load_encoding(cls, pretrained_model_name_or_path, **kwargs): |
| subfolder = kwargs.get("subfolder") |
| base_path = ( |
| os.path.join(pretrained_model_name_or_path, subfolder) |
| if subfolder |
| else pretrained_model_name_or_path |
| ) |
| local_tok_path = os.path.join(base_path, "tokenizer.pkl") |
| if os.path.isfile(local_tok_path): |
| with open(local_tok_path, "rb") as f: |
| return pickle.load(f) |
|
|
| snapshot_kwargs = {k: kwargs[k] for k in kwargs if k in { |
| "cache_dir", |
| "force_download", |
| "local_files_only", |
| "proxies", |
| "resume_download", |
| "revision", |
| "token", |
| "use_auth_token", |
| }} |
|         token = snapshot_kwargs.pop("token", None)
|         # Pop the deprecated `use_auth_token` alias unconditionally so it never
|         # reaches snapshot_download alongside an explicit `token`.
|         use_auth_token = snapshot_kwargs.pop("use_auth_token", None)
|         if token is None:
|             token = use_auth_token
|         if token is not None:
|             snapshot_kwargs["token"] = token
|
|
|         try:
|             snapshot_dir = snapshot_download(pretrained_model_name_or_path, **snapshot_kwargs)
|             tok_path = os.path.join(snapshot_dir, subfolder or "", "tokenizer.pkl")
|         except (HfHubHTTPError, OSError):
|             # The snapshot itself may fail (e.g. gated or partial repo); fall
|             # back to the targeted single-file download below.
|             tok_path = None
|         if tok_path is None or not os.path.isfile(tok_path):
|             try:
|                 tok_path = hf_hub_download(
|                     repo_id=pretrained_model_name_or_path,
|                     filename="tokenizer.pkl",
|                     subfolder=subfolder,
|                     **snapshot_kwargs,
|                 )
|             except (HfHubHTTPError, OSError) as e:
|                 raise ValueError(
|                     f"Could not load tokenizer.pkl from {pretrained_model_name_or_path}. "
|                     f"Make sure the path exists or the repo is accessible on the Hub."
|                 ) from e
| with open(tok_path, "rb") as f: |
| return pickle.load(f) |
|
|
| @classmethod |
| def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): |
| enc = cls._load_encoding(pretrained_model_name_or_path, **kwargs) |
| return cls(enc) |
|
|
| def encode(self, text, prepend=None): |
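|         # `prepend` may be either a raw token id or a special-token string
|         # such as "<|bos|>"; it is placed ahead of the ordinary-text ids.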
| ids = self.enc.encode_ordinary(text) |
| if prepend is not None: |
| prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend) |
| ids.insert(0, prepend_id) |
| return ids |
|
|
| def decode(self, ids): |
| return self.enc.decode(ids) |
|
|
| def get_bos_token_id(self): |
| return self.bos_token_id |
|
|
| def encode_special(self, token): |
| return self.enc.encode_single_token(token) |
|
|
|
|
| class NanoGPTTokenizer(_BaseNanoGPTTokenizer): |
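|     # Usage sketch (illustrative; assumes the checkpoint directory contains a
|     # pickled tiktoken Encoding at tokenizer.pkl, as _load_encoding expects):
|     #
|     #   tok = NanoGPTTokenizer.from_pretrained("path/to/checkpoint")
|     #   ids = tok.encode("hello world", prepend=tok.get_bos_token_id())
|     #   text = tok.decode(ids)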
| pass |
|
|
|
|
| class NanoGPTChatTokenizer(PreTrainedTokenizer): |
| """Transformers-compatible tokenizer with chat helpers.""" |
|
|
| vocab_files_names = {"vocab_file": "tokenizer.pkl"} |
| model_input_names = ["input_ids"] |
|
|
| _special_tokens = { |
| "bos": "<|bos|>", |
| "user_start": "<|user_start|>", |
| "user_end": "<|user_end|>", |
| "assistant_start": "<|assistant_start|>", |
| "assistant_end": "<|assistant_end|>", |
| "python_start": "<|python_start|>", |
| "python_end": "<|python_end|>", |
| "output_start": "<|output_start|>", |
| "output_end": "<|output_end|>", |
| } |
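|
|     # Rendered conversation layout (see _render_conversation_ids):
|     #   <|bos|><|user_start|>USER<|user_end|><|assistant_start|>ASSISTANT<|assistant_end|>...
|     # A leading system message is merged into the first user turn; assistant
|     # tool segments are wrapped in <|python_start|>/<|python_end|> and
|     # <|output_start|>/<|output_end|>.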
|
|
| def __init__( |
| self, |
| vocab_file: str, |
| bos_token: str = "<|bos|>", |
| eos_token: str = "<|assistant_end|>", |
| pad_token: Optional[str] = None, |
| **kwargs, |
| ) -> None: |
| |
| with open(vocab_file, "rb") as f: |
| self.enc = pickle.load(f) |
| self.vocab_file = vocab_file |
|
|
|         self.special_token_ids: Dict[str, int] = {
|             name: self.enc.encode_single_token(token)
|             for name, token in self._special_tokens.items()
|         }
|         # Derive the ids from the constructor arguments rather than hardcoding
|         # them, so a custom bos/eos/pad token is actually honored.
|         self.bos_token_id = self.enc.encode_single_token(bos_token)
|         self.eos_token_id = self.enc.encode_single_token(eos_token)
|         pad_token = pad_token or eos_token
|         self.pad_token_id = self.enc.encode_single_token(pad_token)
|
|
| self._build_vocabulary() |
|
|
| super().__init__( |
| bos_token=bos_token, |
| eos_token=eos_token, |
| pad_token=pad_token, |
| **kwargs, |
| ) |
|
|
|         additional_special_tokens = [
|             token
|             for token in self._special_tokens.values()
|             if token not in {bos_token, eos_token, pad_token}
|         ]
| if additional_special_tokens: |
| self.add_special_tokens({"additional_special_tokens": additional_special_tokens}) |
| self.chat_template = kwargs.get("chat_template", getattr(self, "chat_template", None)) |
|
|
| def _build_vocabulary(self) -> None: |
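|         # Build string views of every token for transformers bookkeeping.
|         # Non-UTF-8 byte sequences decode with errors="replace", so distinct
|         # tokens can collide as strings; byte-accurate work should go through
|         # `self.enc` directly.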
| id_to_token: Dict[int, str] = {} |
| token_to_id: Dict[str, int] = {} |
| for idx in range(self.enc.n_vocab): |
| token_bytes = self.enc.decode_single_token_bytes(idx) |
| token_str = token_bytes.decode("utf-8", errors="replace") |
| id_to_token[idx] = token_str |
| token_to_id[token_str] = idx |
| self._id_to_token = id_to_token |
| self._token_to_id = token_to_id |
|
|
| def get_vocab(self) -> Dict[str, int]: |
| return dict(self._token_to_id) |
|
|
| @property |
| def vocab_size(self) -> int: |
| return self.enc.n_vocab |
|
|
| def _tokenize(self, text: str, **kwargs) -> List[str]: |
| ids = self.enc.encode_ordinary(text) |
| return [self._id_to_token[i] for i in ids] |
|
|
| def _convert_token_to_id(self, token: str) -> int: |
| if token in self._token_to_id: |
| return self._token_to_id[token] |
| raise KeyError(f"Token not found in vocabulary: {token}") |
|
|
| def _convert_id_to_token(self, index: int) -> str: |
| return self._id_to_token[index] |
|
|
| def convert_tokens_to_string(self, tokens: List[str]) -> str: |
| ids = [self._token_to_id[token] for token in tokens] |
| return self.enc.decode(ids) |
|
|
| def build_inputs_with_special_tokens( |
| self, |
| token_ids_0: List[int], |
| token_ids_1: Optional[List[int]] = None, |
| ) -> List[int]: |
| if token_ids_1 is not None: |
| return token_ids_0 + token_ids_1 |
| return token_ids_0 |
|
|
| def get_special_tokens_mask( |
| self, |
| token_ids_0: List[int], |
| token_ids_1: Optional[List[int]] = None, |
| ) -> List[int]: |
|         all_ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
|         # `special_token_ids` maps names to ids, so membership must be tested
|         # against its values; checking the dict directly always yielded 0s.
|         special_ids = set(self.special_token_ids.values())
|         return [1 if token_id in special_ids else 0 for token_id in all_ids]
|
|
| def num_special_tokens_to_add(self, pair: bool = False) -> int: |
| return 0 |
|
|
| def save_vocabulary( |
| self, |
| save_directory: str, |
| filename_prefix: Optional[str] = None, |
| ) -> Tuple[str]: |
| os.makedirs(save_directory, exist_ok=True) |
| filename = "tokenizer.pkl" |
| if filename_prefix is not None: |
| filename = f"{filename_prefix}-{filename}" |
| save_path = os.path.join(save_directory, filename) |
| shutil.copyfile(self.vocab_file, save_path) |
| return (save_path,) |
|
|
| def encode_special(self, token: str) -> int: |
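|         # Accepts either a short name ("bos") or the literal token string
|         # ("<|bos|>").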
| if token in self.special_token_ids: |
| return self.special_token_ids[token] |
| return self._token_to_id[token] |
|
|
| def _encode_text(self, text: str) -> List[int]: |
| return self.enc.encode_ordinary(text) |
|
|
| def _encode_python_block(self, token_id: int, content: str) -> List[int]: |
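|         # Wrap `content` in the matching start/end pair, e.g.
|         # [<|python_start|>, *ids(content), <|python_end|>].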
| tokens = [token_id] |
| tokens.extend(self._encode_text(content)) |
| closing = { |
| self.special_token_ids["python_start"]: self.special_token_ids["python_end"], |
| self.special_token_ids["output_start"]: self.special_token_ids["output_end"], |
| }[token_id] |
| tokens.append(closing) |
| return tokens |
|
|
| def _encode_assistant_content(self, content) -> List[int]: |
| if isinstance(content, str): |
| return self._encode_text(content) |
| if isinstance(content, list): |
| tokens: List[int] = [] |
| for part in content: |
| part_type = part.get("type", "text") |
| text = part.get("text", "") |
| if part_type == "text": |
| tokens.extend(self._encode_text(text)) |
| elif part_type == "python": |
| tokens.extend( |
| self._encode_python_block( |
| self.special_token_ids["python_start"], |
| text, |
| ) |
| ) |
| elif part_type == "python_output": |
| tokens.extend( |
| self._encode_python_block( |
| self.special_token_ids["output_start"], |
| text, |
| ) |
| ) |
| else: |
| raise ValueError(f"Unknown assistant content part: {part_type}") |
| return tokens |
| raise ValueError(f"Unsupported assistant content type: {type(content)}") |
|
|
| def _render_conversation_ids(self, conversation: Sequence[Dict[str, object]]) -> List[int]: |
| if not conversation: |
| raise ValueError("Conversation must contain at least one message") |
| messages = list(conversation) |
| if messages[0]["role"] == "system": |
| if len(messages) < 2 or messages[1]["role"] != "user": |
| raise ValueError("System message must be followed by a user message") |
| merged = dict(messages[1]) |
| merged["content"] = f"{messages[0]['content']}\n\n{messages[1]['content']}" |
| messages = [merged] + messages[2:] |
| ids: List[int] = [self.bos_token_id] |
| for idx, message in enumerate(messages): |
| expected_role = "user" if idx % 2 == 0 else "assistant" |
| role = message.get("role") |
| if role != expected_role: |
| raise ValueError(f"Expected role {expected_role}, received {role} at index {idx}") |
| content = message.get("content") |
| if expected_role == "user": |
| start = self.special_token_ids["user_start"] |
| end = self.special_token_ids["user_end"] |
| if not isinstance(content, str): |
| raise ValueError("User messages must contain string content") |
| ids.append(start) |
| ids.extend(self._encode_text(content)) |
| ids.append(end) |
| else: |
| start = self.special_token_ids["assistant_start"] |
| end = self.special_token_ids["assistant_end"] |
| ids.append(start) |
| ids.extend(self._encode_assistant_content(content)) |
| ids.append(end) |
| return ids |
|
|
| def apply_chat_template( |
| self, |
| conversation, |
| tokenize: bool = False, |
| add_generation_prompt: bool = False, |
| return_tensors: Optional[str] = None, |
| padding: bool = False, |
| truncation: bool = False, |
| max_length: Optional[int] = None, |
| **kwargs, |
| ): |
| if isinstance(conversation, dict) and "messages" in conversation: |
| messages = conversation["messages"] |
| else: |
| messages = conversation |
| token_ids = self._render_conversation_ids(messages) |
| if add_generation_prompt: |
| token_ids.append(self.special_token_ids["assistant_start"]) |
| if tokenize: |
| if return_tensors is not None: |
| return self( |
| [token_ids], |
| add_special_tokens=False, |
| return_tensors=return_tensors, |
| padding=padding, |
| truncation=truncation, |
| max_length=max_length, |
| **kwargs, |
| ) |
| return token_ids |
| return self.decode(token_ids, skip_special_tokens=False) |
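|
|     # Usage sketch (illustrative values):
|     #
|     #   ids = tok.apply_chat_template(
|     #       [{"role": "user", "content": "2+2?"}],
|     #       tokenize=True,
|     #       add_generation_prompt=True,
|     #   )
|     #   # -> [<|bos|>, <|user_start|>, *ids("2+2?"), <|user_end|>, <|assistant_start|>]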
|
|
| def encode_chat_message(self, role: str, content: str) -> List[int]: |
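|         # Note: _render_conversation_ids enforces alternation starting with a
|         # user turn, so a single message is only valid with role="user".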
| rendered = self.apply_chat_template( |
| [ |
| {"role": role, "content": content}, |
| ], |
| tokenize=True, |
| add_generation_prompt=False, |
| ) |
| return rendered |
|
|
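| if __name__ == "__main__":
|     # Minimal smoke test: a sketch assuming a pickled tiktoken Encoding is
|     # available at ./tokenizer.pkl (the path is illustrative).
|     tok = NanoGPTChatTokenizer("tokenizer.pkl")
|     ids = tok.apply_chat_template(
|         [{"role": "user", "content": "Hello!"}],
|         tokenize=True,
|         add_generation_prompt=True,
|     )
|     print(ids)
|     print(tok.decode(ids, skip_special_tokens=False))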