""" Custom Gemma Tokenizer for Bracket Format This tokenizer implements the bracket format for message processing: Format: {description}\n{input}\n<<{output}>>\n The bracket format wraps output content in double angle brackets (<>) and includes proper loss computation flags for training. To save: uv run tokenizers/gemma_bracket_tokenizer.py which will save the tokenizer to the repos/bracket-gemma-12b-pt directory. To load: uv run tokenizers/gemma_bracket_tokenizer.py which will load the tokenizer from the repos/bracket-gemma-12b-pt directory. To test: uv run tokenizers/gemma_bracket_tokenizer.py """ from typing import List, Dict, Any, Optional, Union from transformers import AutoTokenizer from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast from transformers.models.gemma.tokenization_gemma import GemmaTokenizer import warnings import difflib import json import os class GemmaBracketTokenizer(GemmaTokenizerFast): """ Custom tokenizer for Gemma models that implements bracket format message processing. This tokenizer formats messages using the bracket format where: - Description and input content are displayed as plain text with newlines - Output content is wrapped in double angle brackets: <> - Loss is computed on the bracketed output sections Attributes: start_string (str): The starting string used for output generation ("<<") end_string (str): The ending string used for output generation (">>") """ def __init__(self, *args, **kwargs): """ Initialize the custom tokenizer. Accepts the same arguments as GemmaTokenizerFast. """ super().__init__(*args, **kwargs) # Store the end string for bracket format # self.start_string = "<<" # self.end_string = ">>" self.start_string = "" self.end_string = "" # Add custom attributes to the tokenizer config for saving/loading if not hasattr(self, 'init_kwargs'): self.init_kwargs = {} self.init_kwargs['start_string'] = self.start_string self.init_kwargs['end_string'] = self.end_string @classmethod # def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): """ Load a tokenizer from a pretrained model or path. This method ensures our custom class is used instead of the base GemmaTokenizerFast. """ # TODO - there's a warning here when loading the tokenizer from the hub # Load the base tokenizer first to get all configuration base_tokenizer = GemmaTokenizerFast.from_pretrained( pretrained_model_name_or_path, *args, **kwargs ) # Create new instance of our custom class by copying the base tokenizer custom_tokenizer = cls.__new__(cls) # Copy all attributes from base tokenizer for attr, value in base_tokenizer.__dict__.items(): setattr(custom_tokenizer, attr, value) # Initialize our custom attributes # custom_tokenizer.start_string = "<<" # custom_tokenizer.end_string = ">>" custom_tokenizer.start_string = "" custom_tokenizer.end_string = "" # Update init_kwargs to include our custom attributes if not hasattr(custom_tokenizer, 'init_kwargs'): custom_tokenizer.init_kwargs = {} custom_tokenizer.init_kwargs['start_string'] = custom_tokenizer.start_string custom_tokenizer.init_kwargs['end_string'] = custom_tokenizer.end_string return custom_tokenizer def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): """ Save the tokenizer to a directory, including custom configuration. """ # Call parent save method super().save_pretrained(save_directory, **kwargs) # Save custom configuration config_file = os.path.join(save_directory, "tokenizer_config.json") if os.path.exists(config_file): with open(config_file, 'r') as f: config = json.load(f) else: config = {} # Add our custom class info config["tokenizer_class"] = "GemmaBracketTokenizer" config["start_string"] = self.start_string config["end_string"] = self.end_string # Point to our custom class in the uploaded file # config["tokenizer_class"] = "GemmaTokenizerCustom" config["auto_map"] = { "AutoTokenizer": ["gemma_bracket_tokenizer.GemmaBracketTokenizer", "gemma_bracket_tokenizer.GemmaBracketTokenizer"] } with open(config_file, 'w') as f: json.dump(config, f, indent=2) def messages_to_loss_texts( self, messages: List[Dict[str, Any]], loss_on_start_token: bool = False, ) -> List[Dict[str, Any]]: """ From messages (description / input / output) to texts (text / compute_loss) with whether or not loss should be calculated on the text for training. """ texts = [] for message in messages: role = message["role"] content = message["content"] if role == "description": text = f"{content}\n" texts.append({"text": text, "compute_loss": False, **message}) elif role == "input": text = f"{content}\n" texts.append({"text": text, "compute_loss": False, **message}) elif role == "output": if loss_on_start_token: # For output, wrap content in double angle brackets and include newline # text = f"<<{content}>>\n" text = f"{self.start_string}{content}{self.end_string}\n" texts.append({"text": text, "compute_loss": True, **message}) else: texts.append({"text": self.start_string, "compute_loss": False, **message}) text = f"{content}{self.end_string}\n" texts.append({"text": text, "compute_loss": True, **message}) else: raise ValueError(f"Unknown role: {role}. Must be description, input, or output.") # # Add generation prompt if start_generation is True # if start_generation: # texts.append({"text": self.start_string, "compute_loss": False}) return texts def messages_to_text( self, messages: List[Dict[str, Any]], start_generation: bool = False, ) -> str: """ Messages (description / input / output) to raw text (text). """ texts = self.messages_to_loss_texts(messages) text = "".join([text["text"] for text in texts]) if start_generation: text = text + self.start_string return text def tokenize_messages( self, messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]], start_generation: bool = False, **kwargs, ): """ For tokenizing from messages to texts. Supports batching. Good for generation """ if isinstance(messages, list) and isinstance(messages[0], list): # Handle list of lists of messages all_texts = [] for message_list in messages: texts = self.messages_to_raw_text(message_list, start_generation) all_texts.append(texts) else: # Handle single list of messages texts = self.messages_to_raw_text(messages, start_generation) all_texts = [texts] # Tokenize all texts processed = self(text=all_texts, **kwargs) return processed def tokenize_loss_texts( self, texts: List[Dict[str, Any]], loss_on_start_token: bool = False, loss_on_eos: bool = False, include_eos: bool = True, ): """ Tokenize texts (text / compute_loss) to tokenized texts (input_ids / attention_mask / labels). Needs more complex logic to handle the back and forth labeling. """ if loss_on_eos: raise ValueError("Loss on EOS is not currently supported.") # Handle single string input if isinstance(texts, str): processed = self(text=texts) # Add EOS token if needed if (self.eos_token_id is not None and processed["input_ids"][-1] != self.eos_token_id): processed["input_ids"] = processed["input_ids"] + [self.eos_token_id] processed["attention_mask"] = processed["attention_mask"] + [1] return processed # Handle list of text dictionaries all_processed = [] all_texts = '' example_inds = [] dataset_inds = [] for i, item in enumerate(texts): processed = self(text=item["text"]) # Remove BOS token from all but first item if i != 0 and self.bos_token_id == processed["input_ids"][0]: processed["input_ids"] = processed["input_ids"][1:] processed["attention_mask"] = processed["attention_mask"][1:] # Remove EOS token if present at the end if processed["input_ids"][-1] == self.eos_token_id: processed["input_ids"] = processed["input_ids"][:-1] processed["attention_mask"] = processed["attention_mask"][:-1] # Check for EOS token in the middle (with special handling for <|im_end|>) if self.eos_token_id in processed["input_ids"]: if not self.decode([self.eos_token_id]) == "<|im_end|>": raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.") # Set labels based on compute_loss flag if item["compute_loss"]: processed["labels"] = processed["input_ids"].copy() else: processed["labels"] = [-100] * len(processed["input_ids"]) # Remove duplicate BOS tokens if all_processed: if processed["input_ids"][0] == self.bos_token_id: processed["input_ids"] = processed["input_ids"][1:] processed["attention_mask"] = processed["attention_mask"][1:] processed["labels"] = processed["labels"][1:] all_processed.append(processed) all_texts += item["text"] # Handle example indices this_num = -1 if 'example_ind' in item.keys(): if item["example_ind"] is not None: this_num = item["example_ind"] example_inds.extend([this_num] * len(processed["input_ids"])) # Handle dataset indices dataset_ind = -1 if "data_id" in item.keys(): if item["data_id"] is not None: dataset_ind = item["data_id"] dataset_inds.extend([dataset_ind] * len(processed["input_ids"])) # Combine all processed results processed = all_processed[0].copy() processed["input_ids"] = [item for sublist in [p["input_ids"] for p in all_processed] for item in sublist] processed["attention_mask"] = [item for sublist in [p["attention_mask"] for p in all_processed] for item in sublist] processed["labels"] = [item for sublist in [p["labels"] for p in all_processed] for item in sublist] processed["example_inds"] = example_inds processed["data_ids"] = dataset_inds # Validate by tokenizing all_texts at once and comparing processed_all = self(text=all_texts) if len(processed_all["input_ids"]) != len(processed["input_ids"]): warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}") # Generate diff for debugging all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False) processed_text = self.decode(processed["input_ids"], skip_special_tokens=False) diff = difflib.unified_diff(all_text.splitlines(), processed_text.splitlines()) diff_str = "\n".join(diff) print("Diff between texts:") print(diff_str) # Token diff all_tokens_str = '\n'.join([str(s) for s in processed_all["input_ids"]]) processed_tokens_str = '\n'.join([str(s) for s in processed["input_ids"]]) token_diff = difflib.unified_diff(all_tokens_str.splitlines(), processed_tokens_str.splitlines()) token_diff_str = "\n".join(token_diff) print("Diff between tokenized texts:") print(token_diff_str) # Add EOS token if needed if (self.eos_token_id is not None and processed["input_ids"][-1] != self.eos_token_id): processed["input_ids"] = processed["input_ids"] + [self.eos_token_id] processed["example_inds"] = processed["example_inds"] + [-1] processed["attention_mask"] = processed["attention_mask"] + [1] if processed["labels"] is not None: if loss_on_eos: processed["labels"] = processed["labels"] + [self.eos_token_id] else: processed["labels"] = processed["labels"] + [-100] if "data_ids" in processed: processed["data_ids"] = processed["data_ids"] + [-1] if not include_eos: # check if EOS token is present if processed["input_ids"][-1] == self.eos_token_id: # remove EOS token processed["input_ids"] = processed["input_ids"][:-1] processed["attention_mask"] = processed["attention_mask"][:-1] processed["labels"] = processed["labels"][:-1] processed["example_inds"] = processed["example_inds"][:-1] processed["data_ids"] = processed["data_ids"][:-1] return processed def tokenize_messages( self, messages: List[Dict[str, Any]], loss_on_start_token: bool = False, loss_on_eos: bool = False, include_eos: bool = True, ) -> Dict[str, Any]: """ Intended for tokenize from messages to tokenized texts with the loss applied. """ # First convert messages to text with loss computation flags texts = self.messages_to_text_loss(messages, loss_on_start_token) # Then tokenize the texts return self.tokenize_loss_texts(texts, loss_on_eos, include_eos = include_eos) # Register tokenizer classes for AutoTokenizer # Note: We register them separately to avoid conflicts AutoTokenizer.register("GemmaBracketTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaBracketTokenizer) # AutoTokenizer.register("GemmaBracketTokenizerSlow", slow_tokenizer_class=GemmaBracketTokenizerSlow, fast_tokenizer_class=None) if __name__ == "__main__": # Example usage try: # for first load custom_tokenizer = GemmaBracketTokenizer.from_gemma_pretrained("google/gemma-3-1b-pt") # for subsequent loads # custom_tokenizer = GemmaBracketTokenizer.from_pretrained("tsor13/bracket-gemma-12b-pt") # custom_tokenizer = GemmaBracketTokenizer.from_pretrained("repos/bracket-gemma-12b-pt") # Test messages in role/content format messages = [ {"role": "description", "content": "This is a test task"}, {"role": "input", "content": "What is 2+2?"}, {"role": "output", "content": "4"}, {"role": "input", "content": "What is 3+3?"}, # {"role": "output", "content": "6"} ] # tokenized = custom_tokenizer.tokenize_messages(messages, start_generation=True, return_tensors="pt") # print(tokenized) # get messages to text_loss texts = custom_tokenizer.messages_to_loss_texts(messages) print(texts) text = custom_tokenizer.messages_to_text(messages, start_generation=True) print(text) print("\nTesting save/load cycle:") # Test saving and loading tokenizer_path = "repos/bracket-gemma-tokenizer" custom_tokenizer.save_pretrained(tokenizer_path) print("Tokenizer saved successfully!") # also save this file in the tokenizer_path import shutil shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_bracket_tokenizer.py")) print("GemmaBracketTokenizer.py saved successfully!") except Exception as e: print(f"Error during testing: {e}") import traceback traceback.print_exc()