"""Custom Gemma tokenizer for the chat format.

This tokenizer converts generative-process messages (roles ``description`` /
``input`` / ``output``) into chat messages (roles ``system`` / ``user`` /
``assistant``), renders them with the model's built-in chat template, and
marks which spans of the rendered text should contribute to the training
loss. Loss is computed on the assistant ("model") turns only.

To save:

    uv run tokenizers/gemma_chat_tokenizer.py

which saves the tokenizer to the ``repos/chat-gemma-tokenizer`` directory.
Then, to publish a model with it:

    mkdir repos/chat12b
    # copy model over
    cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_chat/checkpoint-8/* repos/chat12b/
    # copy tokenizer over
    cp repos/chat-gemma-tokenizer/* repos/chat12b/
    # upload to hf
    uv run upload_to_hf.py \\
        --folder repos/chat12b \\
        --repo-id tsor13/chat12b
"""
from typing import List, Dict, Any, Optional, Union
from transformers import AutoTokenizer
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
import warnings
import difflib
import json
import os
import sys

# Marker that closes a turn in Gemma's chat template. It is used to end an
# assistant loss span whenever ``end_string`` is empty (the default).
_GEMMA_END_OF_TURN = "<end_of_turn>"

# System prompt shared with chat_utils.py.
# NOTE(review): the line breaks in this prompt were reconstructed from a
# whitespace-collapsed copy of this file. Confirm it matches chat_utils.py
# byte-for-byte, since models are trained on this exact text.
_CHAT_SYSTEM_PROMPT = """You are tasked with generating outputs from a particular, potentially stochastic, generative process. You will be given some combination of:
- Description: A natural description of the generative process / data distribution
- Input: An input on which to condition the generative process.
- Example outputs: Example outputs from the process, either in a user message or as prior generations from a chat message. You may assume that any given outputs are exchangeable with one another (order-invariant) and generated from the same process (roughly i.i.d.).
If the output data pertains to a single object, it just contains the output. If it contains multiple objects, use json formatting with keys for the name of the output variable.
You will be provided at least either a description or an example output.

Given these components, your job is to generate JUST the output in your response, roughly approximating the underlying generative process, maintaining any underlying stochasticity (if any is present). If you are asked to generate again, you will either be given an additional input to condition on, or will just be told to "Generate"."""


class GemmaChatTokenizer(GemmaTokenizerFast):
    """Gemma fast tokenizer that implements chat-format message processing.

    Messages are formatted as follows:
    - The task description and/or the first example output go into a system
      message built on top of a fixed task prompt.
    - Inputs become user turns. When an output is not directly preceded by a
      user turn, ``default_user_message`` is inserted as the user turn.
    - Outputs become assistant turns.
    - The conversation is rendered with the model's built-in chat template.
    - Loss is computed on the assistant turns only.

    Attributes:
        start_string (str): Starting string used for output generation. It
            defaults to "" and is saved to the config, but no method in this
            class reads it.
        end_string (str): String that terminates an assistant turn in the
            rendered template. When empty (the default), "<end_of_turn>" is
            used instead.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the tokenizer.

        Accepts the same arguments as ``GemmaTokenizerFast``. The
        ``start_string`` / ``end_string`` values that ``save_pretrained``
        writes to tokenizer_config.json are honored when passed back in.
        Both default to "".
        """
        super().__init__(*args, **kwargs)
        # Restore persisted values instead of unconditionally resetting them.
        self.start_string = kwargs.get("start_string") or ""
        self.end_string = kwargs.get("end_string") or ""

        # Record the custom attributes so they are written out on save.
        if not hasattr(self, "init_kwargs"):
            self.init_kwargs = {}
        self.init_kwargs["start_string"] = self.start_string
        self.init_kwargs["end_string"] = self.end_string

    @classmethod
    def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a stock Gemma tokenizer and convert it into this class.

        The base ``GemmaTokenizerFast`` is loaded normally. Its instance state
        is then copied (shallowly) onto an uninitialized ``GemmaChatTokenizer``,
        so our custom methods are available without going through
        ``__init__`` again.

        Args:
            pretrained_model_name_or_path: Hub id or local path of a Gemma
                tokenizer.
            *args, **kwargs: Forwarded to ``GemmaTokenizerFast.from_pretrained``.

        Returns:
            A ``GemmaChatTokenizer`` with empty ``start_string`` /
            ``end_string``.
        """
        base_tokenizer = GemmaTokenizerFast.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        # Bypass __init__; the copied __dict__ already holds a fully built tokenizer.
        custom_tokenizer = cls.__new__(cls)
        for attr, value in base_tokenizer.__dict__.items():
            setattr(custom_tokenizer, attr, value)

        custom_tokenizer.start_string = ""
        custom_tokenizer.end_string = ""

        if not hasattr(custom_tokenizer, "init_kwargs"):
            custom_tokenizer.init_kwargs = {}
        custom_tokenizer.init_kwargs["start_string"] = custom_tokenizer.start_string
        custom_tokenizer.init_kwargs["end_string"] = custom_tokenizer.end_string

        return custom_tokenizer

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Save the tokenizer and patch tokenizer_config.json for this class.

        After the parent class writes its files, the config is rewritten to
        contain:
        - ``tokenizer_class``;
        - the ``start_string`` / ``end_string`` attributes;
        - an ``auto_map`` entry pointing AutoTokenizer at
          ``gemma_chat_tokenizer.GemmaChatTokenizer``. That file must be
          copied next to the config; the ``__main__`` block below does this.

        Returns:
            Whatever the parent ``save_pretrained`` returns (the saved file
            names).
        """
        saved_files = super().save_pretrained(save_directory, **kwargs)

        config_file = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}

        config["tokenizer_class"] = "GemmaChatTokenizer"
        config["start_string"] = self.start_string
        config["end_string"] = self.end_string
        # Point to our custom class in the uploaded file (slow and fast slots).
        config["auto_map"] = {
            "AutoTokenizer": [
                "gemma_chat_tokenizer.GemmaChatTokenizer",
                "gemma_chat_tokenizer.GemmaChatTokenizer",
            ]
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        return saved_files

    def messages_to_chat_messages(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
        default_user_message: str = "Generate.",
    ) -> List[Dict[str, Any]]:
        """Convert description/input/output messages into role/content chat messages.

        Uses the same logic as chat_utils.py with system messages.

        Handling of each message role:
        - ``description``: appended to the system prompt. A system message is
          emitted immediately.
        - ``input``: while no description or output has been seen, it is
          folded into the system prompt as "Example Input". Afterwards it
          becomes a user turn.
        - ``output``: the first one, when no description has been seen, is
          folded into the system prompt as "Example Output", and a system
          message is emitted. Afterwards it becomes an assistant turn.
          ``default_user_message`` is inserted first whenever the previous
          turn is not a user turn, so that user/assistant turns alternate.
        - Any other role: ignored.

        Args:
            messages: Dicts with ``role`` and ``content`` keys.
            start_generation: If True and the conversation ends on an
                assistant turn, append a ``default_user_message`` user turn
                so the model is prompted for another output.
            default_user_message: Content of inserted user turns.

        Returns:
            A list of ``{"role", "content"}`` dicts for ``apply_chat_template``.
        """
        system_message = _CHAT_SYSTEM_PROMPT
        chat_messages = []
        has_description_or_output = False

        for message in messages:
            role = message["role"]
            if role == "description":
                system_message += "\n\nDescription: " + message["content"]
                chat_messages.append({"role": "system", "content": system_message})
                has_description_or_output = True
            elif role == "input":
                if not has_description_or_output:
                    system_message += "\n\nExample Input: " + message["content"]
                else:
                    chat_messages.append({"role": "user", "content": message["content"]})
            elif role == "output":
                if not has_description_or_output:
                    system_message += "\n\nExample Output: " + message["content"]
                    chat_messages.append({"role": "system", "content": system_message})
                    has_description_or_output = True
                else:
                    # The previous flag-based check (set once by any input) let an
                    # output follow an assistant or system turn, which breaks
                    # user/assistant alternation. Prompt with the default message
                    # whenever the previous turn is not a user turn.
                    if chat_messages[-1]["role"] != "user":
                        chat_messages.append({"role": "user", "content": default_user_message})
                    chat_messages.append({"role": "assistant", "content": message["content"]})

        if len(chat_messages) == 0:
            # Add the system message, plus an empty user message for now for gemma.
            chat_messages.append({"role": "system", "content": system_message})
            chat_messages.append({"role": "user", "content": ""})
        if len(chat_messages) == 1:
            # Add an empty user message for now for gemma.
            chat_messages.append({"role": "user", "content": ""})

        # If the last message is an output and start_generation is set, ask for another one.
        if start_generation and chat_messages[-1]["role"] == "assistant":
            chat_messages.append({"role": "user", "content": default_user_message})

        return chat_messages

    def messages_to_loss_texts(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        default_user_message: str = "Generate.",
        start_generation: bool = False,
    ) -> List[Dict[str, Any]]:
        """Render messages and split the text into loss / no-loss segments.

        The chat-formatted text is split at every model-turn header
        ("model\\n"):
        - Everything up to and including the header gets
          ``compute_loss=False``.
        - The assistant content up to and including the end-of-turn marker
          gets ``compute_loss=True``. When the marker is absent (e.g. an open
          generation prompt), the assistant segment runs to the end of the
          text.
        - Any trailing text gets ``compute_loss=False``.

        Uses the chat format matching chat_utils.py.

        Args:
            messages: Description/input/output messages.
            loss_on_start_token: Currently ignored (always treated as False).
            default_user_message: Forwarded to ``messages_to_chat_messages``.
            start_generation: Add a generation prompt at the end.

        Returns:
            List of ``{"text": str, "compute_loss": bool}`` segments that
            concatenate to the rendered text.

        Raises:
            ValueError: If the template renders no text at all.
        """
        # FOR NOW, OVERRIDING TO FALSE
        loss_on_start_token = False

        chat_messages = self.messages_to_chat_messages(
            messages, start_generation=start_generation, default_user_message=default_user_message
        )
        full_text = self.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=start_generation
        )

        # The template writes the BOS string itself. Tokenizing the text later
        # prepends BOS again, so strip it here to avoid a double BOS.
        if self.bos_token:
            full_text = full_text.replace(self.bos_token, "")

        model_start_text = "model\n"  # TODO - manual for now, change later
        end_marker = self.end_string or _GEMMA_END_OF_TURN

        texts = []
        remaining = full_text
        while True:
            start_loc = remaining.find(model_start_text)
            if start_loc == -1:
                break
            split_ind = start_loc + len(model_start_text)
            texts.append({"text": remaining[:split_ind], "compute_loss": False})
            remaining = remaining[split_ind:]

            # An empty end marker would match at offset 0 and produce an empty
            # loss span. A missing marker means the turn runs to the end.
            end_loc = remaining.find(end_marker)
            end_ind = len(remaining) if end_loc == -1 else end_loc + len(end_marker)
            # Calculate loss on ALL assistant messages.
            texts.append({"text": remaining[:end_ind], "compute_loss": True})
            remaining = remaining[end_ind:]

        if remaining:
            texts.append({"text": remaining, "compute_loss": False})

        if not texts:
            raise ValueError("Chat template produced no text for the given messages.")
        return texts

    def messages_to_text(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
    ) -> str:
        """Render description/input/output messages to the raw chat-formatted text.

        The result is the concatenation of the segments returned by
        ``messages_to_loss_texts``.
        """
        texts = self.messages_to_loss_texts(messages, start_generation=start_generation)
        return "".join(segment["text"] for segment in texts)

    def tokenize_messages_for_generation(
        self,
        messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
        start_generation: bool = False,
        **kwargs,
    ):
        """Render one conversation, or a batch of them, to text and tokenize it.

        Intended for generation: no loss labels are produced. This method used
        to be a first definition of ``tokenize_messages``, but the loss-aware
        ``tokenize_messages`` defined later in this class shadowed it, which
        made it unreachable.

        Args:
            messages: A single message list, or a list of message lists.
            start_generation: Forwarded to ``messages_to_text``.
            **kwargs: Forwarded to the tokenizer call (e.g. ``padding``,
                ``return_tensors``).

        Returns:
            The tokenizer's batch encoding for the rendered texts. A single
            conversation is wrapped as a batch of one.
        """
        is_batch = isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
        conversations = messages if is_batch else [messages]
        all_texts = [self.messages_to_text(conv, start_generation) for conv in conversations]
        return self(text=all_texts, **kwargs)

    def tokenize_loss_texts(
        self,
        texts: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ):
        """Tokenize loss segments into input_ids / attention_mask / labels.

        Each ``{"text", "compute_loss"}`` segment is tokenized separately:
        - Only the first segment keeps a leading BOS.
        - Trailing EOS tokens are dropped from each segment.
        - Labels copy the input ids for loss segments and are -100 elsewhere.
        - Optional per-item ``example_ind`` / ``data_id`` values are broadcast
          to every token (-1 when absent).

        The concatenation is compared against tokenizing the joined text; on
        a length mismatch, a warning is emitted and diffs are printed. A final
        EOS (label -100) is appended unless already present. It is removed
        again when ``include_eos`` is False.

        Args:
            texts: Loss segments, or a plain string (tokenized without labels).
            loss_on_start_token: Unused.
            loss_on_eos: Must be False; loss on EOS is not supported.
            include_eos: Keep the trailing EOS token.

        Returns:
            The tokenizer output, extended with ``labels``, ``example_inds``
            and ``data_ids`` (only ``input_ids`` / ``attention_mask`` for a
            string input).

        Raises:
            ValueError: If ``loss_on_eos`` is set, ``texts`` is empty, or an
                EOS token appears inside a segment (unless EOS decodes to
                "<|im_end|>").
        """
        if loss_on_eos:
            raise ValueError("Loss on EOS is not currently supported.")

        eos_id = self.eos_token_id
        bos_id = self.bos_token_id

        # A plain string has no loss segmentation: tokenize it and ensure a trailing EOS.
        if isinstance(texts, str):
            processed = self(text=texts)
            if eos_id is not None and (
                not processed["input_ids"] or processed["input_ids"][-1] != eos_id
            ):
                processed["input_ids"] = processed["input_ids"] + [eos_id]
                processed["attention_mask"] = processed["attention_mask"] + [1]
            return processed

        if not texts:
            raise ValueError("tokenize_loss_texts needs at least one text segment.")

        all_processed = []
        text_parts = []
        example_inds = []
        dataset_inds = []
        for i, item in enumerate(texts):
            processed = self(text=item["text"])

            # Only the first segment keeps the BOS the tokenizer prepends.
            if i != 0 and processed["input_ids"] and processed["input_ids"][0] == bos_id:
                processed["input_ids"] = processed["input_ids"][1:]
                processed["attention_mask"] = processed["attention_mask"][1:]

            # Drop a trailing EOS; the final EOS is handled after concatenation.
            if processed["input_ids"] and processed["input_ids"][-1] == eos_id:
                processed["input_ids"] = processed["input_ids"][:-1]
                processed["attention_mask"] = processed["attention_mask"][:-1]

            # Check for EOS token in the middle (with special handling for <|im_end|>).
            if eos_id in processed["input_ids"]:
                if self.decode([eos_id]) != "<|im_end|>":
                    raise ValueError(
                        f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported."
                    )

            if item["compute_loss"]:
                processed["labels"] = list(processed["input_ids"])
            else:
                processed["labels"] = [-100] * len(processed["input_ids"])

            # Remove a duplicate BOS (e.g. a BOS string left in the segment text).
            if all_processed and processed["input_ids"] and processed["input_ids"][0] == bos_id:
                processed["input_ids"] = processed["input_ids"][1:]
                processed["attention_mask"] = processed["attention_mask"][1:]
                processed["labels"] = processed["labels"][1:]

            all_processed.append(processed)
            text_parts.append(item["text"])

            n_tokens = len(processed["input_ids"])
            example_ind = item.get("example_ind")
            example_inds.extend([-1 if example_ind is None else example_ind] * n_tokens)
            data_id = item.get("data_id")
            dataset_inds.extend([-1 if data_id is None else data_id] * n_tokens)

        # Concatenate the per-segment results.
        processed = all_processed[0].copy()
        for key in ("input_ids", "attention_mask", "labels"):
            processed[key] = [tok for part in all_processed for tok in part[key]]
        processed["example_inds"] = example_inds
        processed["data_ids"] = dataset_inds

        # Validate by tokenizing the joined text at once and comparing lengths.
        processed_all = self(text="".join(text_parts))
        if len(processed_all["input_ids"]) != len(processed["input_ids"]):
            warnings.warn(
                f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}"
            )
            all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False)
            processed_text = self.decode(processed["input_ids"], skip_special_tokens=False)
            diff_str = "\n".join(
                difflib.unified_diff(all_text.splitlines(), processed_text.splitlines())
            )
            print("Diff between texts:")
            print(diff_str)

            all_tokens = [str(tok) for tok in processed_all["input_ids"]]
            processed_tokens = [str(tok) for tok in processed["input_ids"]]
            token_diff_str = "\n".join(difflib.unified_diff(all_tokens, processed_tokens))
            print("Diff between tokenized texts:")
            print(token_diff_str)

        # Append a final EOS; loss_on_eos was rejected above, so it never carries loss.
        if eos_id is not None and (
            not processed["input_ids"] or processed["input_ids"][-1] != eos_id
        ):
            processed["input_ids"] = processed["input_ids"] + [eos_id]
            processed["example_inds"] = processed["example_inds"] + [-1]
            processed["attention_mask"] = processed["attention_mask"] + [1]
            processed["labels"] = processed["labels"] + [-100]
            processed["data_ids"] = processed["data_ids"] + [-1]

        if not include_eos and processed["input_ids"] and processed["input_ids"][-1] == eos_id:
            for key in ("input_ids", "attention_mask", "labels", "example_inds", "data_ids"):
                processed[key] = processed[key][:-1]

        return processed

    def tokenize_messages(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ) -> Dict[str, Any]:
        """Tokenize description/input/output messages with loss labels applied.

        Messages are first split into loss segments with
        ``messages_to_loss_texts``, then tokenized with
        ``tokenize_loss_texts``. Arguments are passed by keyword: they were
        previously passed positionally, which sent ``loss_on_eos`` into the
        ``loss_on_start_token`` slot and silently ignored it.

        Raises:
            ValueError: If ``loss_on_eos`` is True (not supported).
        """
        texts = self.messages_to_loss_texts(messages, loss_on_start_token=loss_on_start_token)
        return self.tokenize_loss_texts(
            texts,
            loss_on_start_token=loss_on_start_token,
            loss_on_eos=loss_on_eos,
            include_eos=include_eos,
        )


# Register tokenizer classes for AutoTokenizer.
# NOTE(review): AutoTokenizer.register normally takes a config class as its first
# argument. Registering under a plain string may not affect lookups; loading from
# a saved directory relies on the auto_map written by save_pretrained. Verify.
AutoTokenizer.register("GemmaChatTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaChatTokenizer)


if __name__ == "__main__":
    import shutil

    # For the first load:
    custom_tokenizer = GemmaChatTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")
    # For subsequent loads:
    # custom_tokenizer = GemmaChatTokenizer.from_pretrained("tsor13/chat-gemma-12b-pt")
    # custom_tokenizer = GemmaChatTokenizer.from_pretrained("repos/chat-gemma-12b-pt")

    # Test messages in role/content format.
    test_messages = [
        [
            {"role": "description", "content": "Pick a number between 1 and 100"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
            {"role": "output", "content": "10"},
            {"role": "input", "content": "What is 4+4?"},
        ],
        [
            {"role": "description", "content": "DESCRIPTION"},
            {"role": "input", "content": "INPUT1"},
            {"role": "output", "content": "OUTPUT1"},
            {"role": "input", "content": "INPUT2"},
            {"role": "output", "content": "OUTPUT2"},
        ],
        [
            {"role": "description", "content": "DESCRIPTION"},
            {"role": "output", "content": "OUTPUT1"},
            {"role": "output", "content": "OUTPUT2"},
        ],
    ]

    for messages in test_messages:
        texts = custom_tokenizer.messages_to_loss_texts(messages)
        print("Texts with loss flags:")
        for i, text in enumerate(texts):
            print(f"  {i}: {text}")

        text = custom_tokenizer.messages_to_text(messages, start_generation=True)
        print("\nFull text with generation prompt:")
        print(text)

    print("\nTesting save/load cycle:")
    tokenizer_path = "repos/chat-gemma-tokenizer"
    custom_tokenizer.save_pretrained(tokenizer_path)
    print("Tokenizer saved successfully!")

    # Also save this file next to the config so the auto_map entry resolves.
    shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_chat_tokenizer.py"))
    print("GemmaChatTokenizer.py saved successfully!")