# chatv1/gemma_chat_tokenizer.py
"""
Custom Gemma Tokenizer for chat Format
This tokenizer implements the chat format for message processing:
Format: Uses the standard chat template with proper role labels (user/assistant)
The chat format uses the model's built-in chat template and includes proper
loss computation flags for training with "assistant" as the generation role.
To save:
uv run tokenizers/gemma_chat_tokenizer.py
which will save the tokenizer to the repos/chat-gemma-tokenizer directory.
mkdir repos/chat12b
# copy model over
cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_chat/checkpoint-8/* repos/chat12b/
# copy tokenizer over
cp repos/chat-gemma-tokenizer/* repos/chat12b/
# upload to hf
uv run upload_to_hf.py \
--folder repos/chat12b \
--repo-id tsor13/chat12b
"""
from typing import List, Dict, Any, Union
from transformers import AutoTokenizer
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
import warnings
import difflib
import json
import os
class GemmaChatTokenizer(GemmaTokenizerFast):
"""
Custom tokenizer for Gemma models that implements chat format message processing.
This tokenizer formats messages using the chat format where:
- Messages use the standard chat template with proper role labels
- Uses the model's built-in chat formatting
- Loss is computed on the assistant sections (not output)
Attributes:
start_string (str): The starting string used for output generation (depends on tokenizer)
end_string (str): The ending string used for output generation (depends on tokenizer)
"""
def __init__(self, *args, **kwargs):
"""
Initialize the custom tokenizer.
Accepts the same arguments as GemmaTokenizerFast.
"""
super().__init__(*args, **kwargs)
# For the chat format, we rely on the tokenizer's own chat template.
# These are the fixed turn markers used by the Gemma chat template.
self.start_string = "<start_of_turn>"
self.end_string = "<end_of_turn>"
# Add custom attributes to the tokenizer config for saving/loading
if not hasattr(self, 'init_kwargs'):
self.init_kwargs = {}
self.init_kwargs['start_string'] = self.start_string
self.init_kwargs['end_string'] = self.end_string
@classmethod
def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"""
Load a tokenizer from a pretrained model or path.
This method ensures our custom class is used instead of the base GemmaTokenizerFast.
"""
# Load the base tokenizer first to get all configuration
base_tokenizer = GemmaTokenizerFast.from_pretrained(
pretrained_model_name_or_path, *args, **kwargs
)
# Create new instance of our custom class by copying the base tokenizer
custom_tokenizer = cls.__new__(cls)
# Copy all attributes from base tokenizer
for attr, value in base_tokenizer.__dict__.items():
setattr(custom_tokenizer, attr, value)
# Initialize our custom attributes for chat format
custom_tokenizer.start_string = "<start_of_turn>"
custom_tokenizer.end_string = "<end_of_turn>"
# Update init_kwargs to include our custom attributes
if not hasattr(custom_tokenizer, 'init_kwargs'):
custom_tokenizer.init_kwargs = {}
custom_tokenizer.init_kwargs['start_string'] = custom_tokenizer.start_string
custom_tokenizer.init_kwargs['end_string'] = custom_tokenizer.end_string
return custom_tokenizer
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
"""
Save the tokenizer to a directory, including custom configuration.
"""
# Call parent save method
super().save_pretrained(save_directory, **kwargs)
# Save custom configuration
config_file = os.path.join(save_directory, "tokenizer_config.json")
if os.path.exists(config_file):
with open(config_file, 'r') as f:
config = json.load(f)
else:
config = {}
# Add our custom class info
config["tokenizer_class"] = "GemmaChatTokenizer"
config["start_string"] = self.start_string
config["end_string"] = self.end_string
# Point AutoTokenizer at our custom class in the uploaded module file.
# The auto_map entry is (slow_tokenizer_class, fast_tokenizer_class); only a fast variant exists.
config["auto_map"] = {
"AutoTokenizer": [None, "gemma_chat_tokenizer.GemmaChatTokenizer"]
}
with open(config_file, 'w') as f:
json.dump(config, f, indent=2)
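# After save_pretrained, tokenizer_config.json gains, roughly (illustrative
# of the keys written above):
#   "tokenizer_class": "GemmaChatTokenizer",
#   "start_string": "<start_of_turn>",
#   "end_string": "<end_of_turn>",
#   "auto_map": {"AutoTokenizer": [null, "gemma_chat_tokenizer.GemmaChatTokenizer"]}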
def messages_to_chat_messages(
self,
messages: List[Dict[str, Any]],
start_generation: bool = False,
default_user_message: str = "Generate.",
) -> List[Dict[str, Any]]:
"""
From messages (description / input / output) to chat messages (role / content)
Uses the same logic as chat_utils.py with system messages.
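Illustrative example (with the default default_user_message, "Generate.";
content abbreviated): the messages

    [{"role": "description", "content": "Pick a number"},
     {"role": "output", "content": "42"}]

map to the chat messages

    [{"role": "system", "content": chat_prompt + "\n\nDescription: Pick a number"},
     {"role": "user", "content": "Generate."},
     {"role": "assistant", "content": "42"}]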
"""
chat_prompt = """You are tasked with generating outputs from a particular, potentially stochastic, generative process. You will be given some combination of:
- Description: A natural description of the generative process / data distribution
- Input: An input on which to condition the generative process.
- Example outputs: Example outputs from the process, either in a user message or as prior generations from a chat message. You may assume that any given outputs are exchangeable with one another (order-invariant) and generated from the same process (roughly i.i.d.). If the output data pertains to a single object, it just contains the output. If it contains multiple objects, use json formatting with keys for the name of the output variable.
You will be provided at least either a description or an example output.
Given these components, your job is to generate JUST the output in your response, roughly approximating the underlying generative process, maintaining any underlying stochasticity (if any is present). If you are asked to generate again, you will either be given an additional input to condition on, or will just be told to "Generate"."""
chat_messages = []
system_message = chat_prompt
has_description_or_output = False
has_input = False
for message in messages:
if message["role"] == "description":
system_message += "\n\nDescription: " + message["content"]
chat_messages.append({"role": "system", "content": system_message})
has_description_or_output = True
elif message["role"] == "input":
has_input = True
if not has_description_or_output:
system_message += "\n\nExample Input: " + message["content"]
else:
chat_messages.append({"role": "user", "content": message["content"]})
elif message["role"] == "output":
if not has_description_or_output:
system_message += "\n\nExample Output: " + message["content"]
chat_messages.append({"role": "system", "content": system_message})
has_description_or_output = True
else:
if not has_input:
chat_messages.append({"role": "user", "content": default_user_message})
chat_messages.append({"role": "assistant", "content": message["content"]})
if len(chat_messages) == 0:
# add system message
chat_messages.append({"role": "system", "content": system_message})
# also add in empty user message for now for gemma
chat_messages.append({"role": "user", "content": ""})
if len(chat_messages) == 1:
# add in empty user message for now for gemma
chat_messages.append({"role": "user", "content": ""})
# if the last chat message is an assistant turn and start_generation is set, append a default user message to open a new turn
if start_generation and chat_messages[-1]["role"] == "assistant":
chat_messages.append({"role": "user", "content": default_user_message})
return chat_messages
def messages_to_loss_texts(
self,
messages: List[Dict[str, Any]],
loss_on_start_token: bool = False,
default_user_message: str = "Generate.",
start_generation: bool = False,
) -> List[Dict[str, Any]]:
"""
From messages (description / input / output) to a list of dicts (text / compute_loss) indicating whether loss should be computed on each text span during training.
Uses the chat format matching chat_utils.py with updated loss computation logic.
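Illustrative shape of the return value (strings abbreviated; the exact
markers come from the chat template): for a single output message,

    [{"text": "...<start_of_turn>model\n", "compute_loss": False},
     {"text": "42<end_of_turn>", "compute_loss": True},
     {"text": "\n", "compute_loss": False}]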
"""
# NOTE: loss_on_start_token is currently always overridden to False
loss_on_start_token = False
texts = []
chat_messages = self.messages_to_chat_messages(messages, start_generation=start_generation, default_user_message=default_user_message)
# Apply chat template
full_text = self.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=start_generation)
# strip the <bos> marker; BOS is added again when the text is tokenized
full_text = full_text.replace("<bos>", "")
text_to_split = full_text
# now, find all places starting with <start_of_turn>model\n
model_start_text = "<start_of_turn>model\n" # TODO - manual for now, change later
while model_start_text in text_to_split:
# get location of model_start_text
model_start_loc = text_to_split.find(model_start_text)
split_ind = model_start_loc + len(model_start_text)
text_to_add, text_to_split = text_to_split[:split_ind], text_to_split[split_ind:]
# add to texts
texts.append({"text": text_to_add, "compute_loss": False})
# get location of end_string
end_string_loc = text_to_split.find(self.end_string)
end_ind = end_string_loc + len(self.end_string)
text_to_add, text_to_split = text_to_split[:end_ind], text_to_split[end_ind:]
# Compute loss on ALL assistant messages
texts.append({"text": text_to_add, "compute_loss": True})
if len(text_to_split) > 0:
texts.append({"text": text_to_split, "compute_loss": False})
if len(texts) == 0:
raise ValueError(f"No loss texts were produced from messages: {messages}")
return texts
def messages_to_text(
self,
messages: List[Dict[str, Any]],
start_generation: bool = False,
) -> str:
"""
Messages (description / input / output) to raw text (text).
Uses the chat format matching chat_utils.py.
"""
texts = self.messages_to_loss_texts(messages, start_generation=start_generation)
text = "".join([text["text"] for text in texts])
return text
def tokenize_messages_for_generation(
self,
messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
start_generation: bool = False,
**kwargs,
):
"""
Tokenize messages into chat-formatted text without loss labels. Supports
batching, so it is well suited to generation; see tokenize_messages below
for the training variant that attaches labels.
"""
if isinstance(messages, list) and isinstance(messages[0], list):
# Handle list of lists of messages
all_texts = []
for message_list in messages:
texts = self.messages_to_text(message_list, start_generation)
all_texts.append(texts)
else:
# Handle single list of messages
texts = self.messages_to_text(messages, start_generation)
all_texts = [texts]
# Tokenize all texts
processed = self(text=all_texts, **kwargs)
return processed
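# Usage sketch for generation (illustrative; `tok` is a loaded
# GemmaChatTokenizer, and extra kwargs such as return_tensors are
# forwarded to __call__):
#   batch = tok.tokenize_messages_for_generation(
#       [{"role": "description", "content": "Pick a number"}],
#       start_generation=True, return_tensors="pt",
#   )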
def tokenize_loss_texts(
self,
texts: Union[str, List[Dict[str, Any]]],
loss_on_start_token: bool = False,
loss_on_eos: bool = False,
include_eos: bool = True,
):
"""
Tokenize texts (text / compute_loss) into tokenized output (input_ids / attention_mask / labels).
Requires more involved logic to handle the alternating loss / no-loss labeling.
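Returns a dict with input_ids, attention_mask, labels (set to -100 on
spans where compute_loss is False), plus per-token example_inds and
data_ids (-1 wherever no example_ind / data_id was provided).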
"""
if loss_on_eos:
raise ValueError("Loss on EOS is not currently supported.")
# Handle single string input
if isinstance(texts, str):
processed = self(text=texts)
# Add EOS token if needed
if (self.eos_token_id is not None and
processed["input_ids"][-1] != self.eos_token_id):
processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
processed["attention_mask"] = processed["attention_mask"] + [1]
return processed
# Handle list of text dictionaries
all_processed = []
all_texts = ''
example_inds = []
dataset_inds = []
for i, item in enumerate(texts):
processed = self(text=item["text"])
# Remove BOS token from all but first item
if i != 0 and self.bos_token_id == processed["input_ids"][0]:
processed["input_ids"] = processed["input_ids"][1:]
processed["attention_mask"] = processed["attention_mask"][1:]
# Remove EOS token if present at the end
if processed["input_ids"][-1] == self.eos_token_id:
processed["input_ids"] = processed["input_ids"][:-1]
processed["attention_mask"] = processed["attention_mask"][:-1]
# Check for an EOS token mid-sequence (tolerated only when EOS decodes to the ChatML "<|im_end|>" marker)
if self.eos_token_id in processed["input_ids"]:
if not self.decode([self.eos_token_id]) == "<|im_end|>":
raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.")
# Set labels based on compute_loss flag
if item["compute_loss"]:
processed["labels"] = processed["input_ids"].copy()
else:
processed["labels"] = [-100] * len(processed["input_ids"])
# Remove duplicate BOS tokens
if all_processed:
if processed["input_ids"][0] == self.bos_token_id:
processed["input_ids"] = processed["input_ids"][1:]
processed["attention_mask"] = processed["attention_mask"][1:]
processed["labels"] = processed["labels"][1:]
all_processed.append(processed)
all_texts += item["text"]
# Handle example indices
this_num = -1
if 'example_ind' in item.keys():
if item["example_ind"] is not None:
this_num = item["example_ind"]
example_inds.extend([this_num] * len(processed["input_ids"]))
# Handle dataset indices
dataset_ind = -1
if "data_id" in item.keys():
if item["data_id"] is not None:
dataset_ind = item["data_id"]
dataset_inds.extend([dataset_ind] * len(processed["input_ids"]))
# Combine all processed results
processed = all_processed[0].copy()
processed["input_ids"] = [item for sublist in [p["input_ids"] for p in all_processed] for item in sublist]
processed["attention_mask"] = [item for sublist in [p["attention_mask"] for p in all_processed] for item in sublist]
processed["labels"] = [item for sublist in [p["labels"] for p in all_processed] for item in sublist]
processed["example_inds"] = example_inds
processed["data_ids"] = dataset_inds
# Validate by tokenizing all_texts at once and comparing
processed_all = self(text=all_texts)
if len(processed_all["input_ids"]) != len(processed["input_ids"]):
warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}")
# Generate diff for debugging
all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False)
processed_text = self.decode(processed["input_ids"], skip_special_tokens=False)
diff = difflib.unified_diff(all_text.splitlines(), processed_text.splitlines())
diff_str = "\n".join(diff)
print("Diff between texts:")
print(diff_str)
# Token diff
all_tokens_str = '\n'.join([str(s) for s in processed_all["input_ids"]])
processed_tokens_str = '\n'.join([str(s) for s in processed["input_ids"]])
token_diff = difflib.unified_diff(all_tokens_str.splitlines(), processed_tokens_str.splitlines())
token_diff_str = "\n".join(token_diff)
print("Diff between tokenized texts:")
print(token_diff_str)
# Add EOS token if needed
if (self.eos_token_id is not None and
processed["input_ids"][-1] != self.eos_token_id):
processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
processed["example_inds"] = processed["example_inds"] + [-1]
processed["attention_mask"] = processed["attention_mask"] + [1]
if processed["labels"] is not None:
if loss_on_eos:
processed["labels"] = processed["labels"] + [self.eos_token_id]
else:
processed["labels"] = processed["labels"] + [-100]
if "data_ids" in processed:
processed["data_ids"] = processed["data_ids"] + [-1]
if not include_eos:
# check if EOS token is present
if processed["input_ids"][-1] == self.eos_token_id:
# remove EOS token
processed["input_ids"] = processed["input_ids"][:-1]
processed["attention_mask"] = processed["attention_mask"][:-1]
processed["labels"] = processed["labels"][:-1]
processed["example_inds"] = processed["example_inds"][:-1]
processed["data_ids"] = processed["data_ids"][:-1]
return processed
def tokenize_messages(
self,
messages: List[Dict[str, Any]],
loss_on_start_token: bool = False,
loss_on_eos: bool = False,
include_eos: bool = True,
) -> Dict[str, Any]:
"""
Tokenize messages into tokenized texts with loss labels applied. Intended for training.
"""
# First convert messages to text with loss-computation flags
texts = self.messages_to_loss_texts(messages, loss_on_start_token)
# Then tokenize the texts
return self.tokenize_loss_texts(texts, loss_on_eos=loss_on_eos, include_eos=include_eos)
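# Training usage sketch (illustrative): yields list-valued features with -100
# labels outside assistant turns, ready for causal-LM fine-tuning:
#   features = tok.tokenize_messages([
#       {"role": "description", "content": "Pick a number"},
#       {"role": "output", "content": "42"},
#   ])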
# Register tokenizer classes for AutoTokenizer
AutoTokenizer.register("GemmaChatTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaChatTokenizer)
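# Downstream loading sketch (illustrative): once this file is uploaded next to
# the saved tokenizer, the auto_map written by save_pretrained lets
# AutoTokenizer resolve the custom class:
#   tok = AutoTokenizer.from_pretrained("tsor13/chat12b", trust_remote_code=True)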
if __name__ == "__main__":
# Example usage
# for first load
custom_tokenizer = GemmaChatTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")
# for subsequent loads
# custom_tokenizer = GemmaChatTokenizer.from_pretrained("tsor13/chat-gemma-12b-pt")
# custom_tokenizer = GemmaChatTokenizer.from_pretrained("repos/chat-gemma-12b-pt")
# Test messages in role/content format
test_messages = [
[
{"role": "description", "content": "Pick a number between 1 and 100"},
],
[
{"role": "description", "content": "This is a test task"},
{"role": "input", "content": "What is 2+2?"},
{"role": "output", "content": "4"},
{"role": "input", "content": "What is 3+3?"},
],
[
{"role": "description", "content": "This is a test task"},
{"role": "output", "content": "4"},
{"role": "output", "content": "10"},
{"role": "output", "content": "13"},
],
[
{"role": "output", "content": "4"},
{"role": "output", "content": "10"},
{"role": "output", "content": "13"},
],
[
{"role": "input", "content": "What is 2+2?"},
{"role": "output", "content": "4"},
{"role": "input", "content": "What is 3+3?"},
{"role": "output", "content": "10"},
{"role": "input", "content": "What is 4+4?"},
],
[
{"role": "description", "content": "DESCRIPTION"},
{"role": "input", "content": "INPUT1"},
{"role": "output", "content": "OUTPUT1"},
{"role": "input", "content": "INPUT2"},
{"role": "output", "content": "OUTPUT2"},
],
[
{"role": "description", "content": "DESCRIPTION"},
{"role": "output", "content": "OUTPUT1"},
{"role": "output", "content": "OUTPUT2"},
],
]
for messages in test_messages:
# get messages to text_loss
texts = custom_tokenizer.messages_to_loss_texts(messages)
print("Texts with loss flags:")
for i, text in enumerate(texts):
print(f" {i}: {text}")
text = custom_tokenizer.messages_to_text(messages, start_generation=True)
print(f"\nFull text with generation prompt:")
print(text)
print("\nTesting save/load cycle:")
# Test saving and loading
tokenizer_path = "repos/chat-gemma-tokenizer"
custom_tokenizer.save_pretrained(tokenizer_path)
print("Tokenizer saved successfully!")
# also copy this file into the tokenizer directory so that
# trust_remote_code loading can import the class referenced by auto_map
import shutil
shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_chat_tokenizer.py"))
print("GemmaChatTokenizer.py saved successfully!")