|
|
""" |
|
|
Custom Gemma Tokenizer for the Chat Format
|
|
|
|
|
This tokenizer implements the chat format for message processing: |
|
|
Format: Uses the standard chat template with proper role labels (user/assistant) |
|
|
|
|
|
The chat format uses the model's built-in chat template and includes proper |
|
|
loss computation flags for training with "assistant" as the generation role. |
|
|
|
|
|
To save: |
|
|
uv run tokenizers/gemma_chat_tokenizer.py |
|
|
which saves the tokenizer, plus a copy of this file as gemma_chat_tokenizer.py, to the repos/chat-gemma-tokenizer directory.
|
|
mkdir repos/chat12b |
|
|
# copy model over |
|
|
cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_chat/checkpoint-8/* repos/chat12b/ |
|
|
# copy tokenizer over |
|
|
cp repos/chat-gemma-tokenizer/* repos/chat12b/ |
|
|
# upload to hf |
|
|
|
|
|
uv run upload_to_hf.py \ |
|
|
--folder repos/chat12b \ |
|
|
--repo-id tsor13/chat12b |
|
|
""" |
|
|
|
|
|
from typing import List, Dict, Any, Optional, Union |
|
|
from transformers import AutoTokenizer |
|
|
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast |
|
|
from transformers.models.gemma.tokenization_gemma import GemmaTokenizer |
|
|
import warnings |
|
|
import difflib |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
|
|
|
|
|
|
class GemmaChatTokenizer(GemmaTokenizerFast):
    """
    Custom tokenizer for Gemma models that implements chat format message processing.

    Messages use three custom roles that are mapped onto the model's chat template:

    - ``description`` -> a ``system`` message asking for output that fits the description
    - ``input``       -> a ``user`` message
    - ``output``      -> an ``assistant`` message

    Loss is computed only on assistant (model) turns; see ``messages_to_loss_texts``
    for the exact rules.

    Attributes:
        start_string (str): Marker that opens a chat turn (``"<start_of_turn>"``).
        end_string (str): Marker that closes a chat turn (``"<end_of_turn>"``).
    """

    # Turn markers used by the Gemma chat template.
    _START_STRING = "<start_of_turn>"
    _END_STRING = "<end_of_turn>"
    # Opens an assistant turn in the rendered template; everything up to and including it
    # is prompt text.
    _MODEL_TURN_PREFIX = "<start_of_turn>model\n"
    # Prefix of the system message built from a "description" message.
    _DESCRIPTION_PREFIX = (
        "Generate something that fits this description. Don't generate anything else, "
        "just the desired generation output.\nDescription: "
    )

    def __init__(self, *args, **kwargs):
        """
        Initialize the custom tokenizer.

        Accepts the same arguments as GemmaTokenizerFast.
        """
        super().__init__(*args, **kwargs)
        self._set_turn_markers()

    def _set_turn_markers(self):
        """Set ``start_string``/``end_string`` and mirror them into ``init_kwargs``."""
        self.start_string = self._START_STRING
        self.end_string = self._END_STRING
        if not hasattr(self, "init_kwargs"):
            self.init_kwargs = {}
        self.init_kwargs["start_string"] = self.start_string
        self.init_kwargs["end_string"] = self.end_string

    @classmethod
    def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load a stock Gemma tokenizer and return it as a ``GemmaChatTokenizer``.

        The tokenizer is loaded with ``GemmaTokenizerFast.from_pretrained`` (all arguments
        are forwarded). Its instance attributes are then copied onto an uninitialized
        instance of this class, so vocabulary and configuration are those of the base
        tokenizer.
        """
        base_tokenizer = GemmaTokenizerFast.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        # __new__ skips __init__; the loaded state is adopted attribute by attribute instead.
        custom_tokenizer = cls.__new__(cls)
        for attr, value in base_tokenizer.__dict__.items():
            setattr(custom_tokenizer, attr, value)

        custom_tokenizer._set_turn_markers()
        return custom_tokenizer

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save the tokenizer to a directory, including custom configuration.

        After the parent class writes its files, ``tokenizer_config.json`` is rewritten so that
        ``tokenizer_class``, ``start_string``, ``end_string`` and an ``auto_map`` entry point
        at ``gemma_chat_tokenizer.GemmaChatTokenizer``.

        Returns:
            Whatever the parent ``save_pretrained`` returns (its list of saved files).
        """
        saved_files = super().save_pretrained(save_directory, **kwargs)

        config_file = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}

        config["tokenizer_class"] = "GemmaChatTokenizer"
        config["start_string"] = self.start_string
        config["end_string"] = self.end_string

        # Points AutoTokenizer at the gemma_chat_tokenizer module, which the __main__ script
        # of this file copies next to the saved tokenizer files.
        config["auto_map"] = {
            "AutoTokenizer": [
                "gemma_chat_tokenizer.GemmaChatTokenizer",
                "gemma_chat_tokenizer.GemmaChatTokenizer",
            ]
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

        return saved_files

    def _to_chat_messages(
        self,
        messages: List[Dict[str, Any]],
        default_user_message: str,
        start_generation: bool,
    ) -> List[Dict[str, str]]:
        """Map description/input/output messages onto system/user/assistant chat messages."""
        chat_messages: List[Dict[str, str]] = []
        for message in messages:
            role = message["role"]
            if role == "description":
                chat_messages.append(
                    {"role": "system", "content": self._DESCRIPTION_PREFIX + message["content"]}
                )
            elif role == "input":
                chat_messages.append({"role": "user", "content": message["content"]})
            elif role == "output":
                # Keep user/assistant alternation (presumably enforced by the chat template):
                # an output that does not directly follow a user turn gets a placeholder one.
                if not chat_messages or chat_messages[-1]["role"] != "user":
                    chat_messages.append({"role": "user", "content": default_user_message})
                chat_messages.append({"role": "assistant", "content": message["content"]})
            # Any other role is ignored.

        # Generating a new output after an assistant turn needs a user turn to respond to.
        if start_generation and chat_messages and chat_messages[-1]["role"] == "assistant":
            chat_messages.append({"role": "user", "content": default_user_message})

        # A lone description (system message) also needs a user turn to respond to.
        if len(chat_messages) == 1 and chat_messages[0]["role"] == "system":
            chat_messages.append({"role": "user", "content": default_user_message})

        return chat_messages

    def _split_loss_segments(
        self, full_text: str, loss_on_first_output: bool
    ) -> List[Dict[str, Any]]:
        """
        Split rendered chat text into prompt segments (no loss) and model-turn segments.

        Each model turn runs from just after ``<start_of_turn>model\\n`` through its closing
        ``end_string``. An unterminated final turn (e.g. a generation prompt) runs to the end
        of the text. The first model turn gets loss only if ``loss_on_first_output`` is set.
        Empty model-turn segments are not emitted.
        """
        texts: List[Dict[str, Any]] = []
        remaining = full_text
        first = True
        while True:
            start = remaining.find(self._MODEL_TURN_PREFIX)
            if start == -1:
                break
            split_ind = start + len(self._MODEL_TURN_PREFIX)
            texts.append({"text": remaining[:split_ind], "compute_loss": False})
            remaining = remaining[split_ind:]

            end = remaining.find(self.end_string)
            end_ind = len(remaining) if end == -1 else end + len(self.end_string)
            turn, remaining = remaining[:end_ind], remaining[end_ind:]
            if turn:
                texts.append({"text": turn, "compute_loss": loss_on_first_output or not first})
            first = False

        if remaining:
            texts.append({"text": remaining, "compute_loss": False})
        return texts

    def messages_to_loss_texts(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        default_user_message: str = "Generate.",
        start_generation: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Convert messages (description / input / output) into text segments with loss flags.

        The messages are rendered with the tokenizer's chat template, matching the chat
        format used in chat_utils.py, and then split into segments.

        Loss rules:
        - Prompt text up to and including ``<start_of_turn>model\\n`` gets no loss.
        - Any text after the last model turn gets no loss.
        - Each model turn, through ``<end_of_turn>``, gets loss. The exception is the first
          model turn when the messages contain no ``description``.

        Args:
            messages: Dicts with ``"role"`` in {"description", "input", "output"} and
                ``"content"``. Other roles are ignored.
            loss_on_start_token: Accepted for interface compatibility; ignored.
            default_user_message: Content of the placeholder user turn inserted where an
                output (or a generation request) has no preceding user turn.
            start_generation: Append the template's generation prompt (and a placeholder
                user turn if the conversation ends on an assistant turn).

        Returns:
            A list of ``{"text": str, "compute_loss": bool}``. Concatenating the texts
            gives the rendered chat with ``"<bos>"`` removed.
        """
        chat_messages = self._to_chat_messages(messages, default_user_message, start_generation)
        has_description = any(message["role"] == "description" for message in messages)

        full_text = self.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=start_generation
        )
        # BOS is presumably re-added by the tokenizer when these texts are tokenized.
        full_text = full_text.replace("<bos>", "")

        return self._split_loss_segments(full_text, loss_on_first_output=has_description)

    def messages_to_text(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
    ) -> str:
        """
        Convert messages (description / input / output) to the raw rendered chat text.

        Uses the same chat format as ``messages_to_loss_texts`` (without ``"<bos>"``).
        """
        segments = self.messages_to_loss_texts(messages, start_generation=start_generation)
        return "".join(segment["text"] for segment in segments)

    def tokenize_messages_for_generation(
        self,
        messages: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
        start_generation: bool = False,
        **kwargs,
    ):
        """
        Tokenize one conversation, or a batch of conversations, as plain text.

        This is intended for generation. It is named separately from ``tokenize_messages``
        (the loss-aware tokenizer) because two methods with the same name would shadow
        each other.

        Args:
            messages: A single list of message dicts, or a list of such lists (a batch).
            start_generation: Forwarded to ``messages_to_text``.
            **kwargs: Forwarded to the tokenizer call (e.g. ``padding``, ``return_tensors``).

        Returns:
            The tokenizer output for the list of rendered texts (always batched).
        """
        if isinstance(messages, list) and messages and isinstance(messages[0], list):
            all_texts = [self.messages_to_text(conversation, start_generation) for conversation in messages]
        else:
            all_texts = [self.messages_to_text(messages, start_generation)]
        return self(text=all_texts, **kwargs)

    def _print_tokenization_diff(self, whole_ids: List[int], segmented_ids: List[int]):
        """Print text-level and token-level unified diffs between two tokenizations."""
        whole_text = self.decode(whole_ids, skip_special_tokens=False)
        segmented_text = self.decode(segmented_ids, skip_special_tokens=False)
        print("Diff between texts:")
        print("\n".join(difflib.unified_diff(whole_text.splitlines(), segmented_text.splitlines())))

        print("Diff between tokenized texts:")
        token_diff = difflib.unified_diff(
            [str(token) for token in whole_ids], [str(token) for token in segmented_ids]
        )
        print("\n".join(token_diff))

    def tokenize_loss_texts(
        self,
        texts: Union[str, List[Dict[str, Any]]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ):
        """
        Tokenize text segments into ``input_ids`` / ``attention_mask`` / ``labels``.

        Each segment is tokenized separately so its loss flag can be applied. Only the first
        segment keeps its BOS, and a trailing EOS is stripped from every segment. A single
        EOS is then appended to the whole sequence (or removed if ``include_eos`` is False).

        As a sanity check, the concatenated text is re-tokenized in one pass:
        - A warning is emitted if its length differs from the stitched tokenization.
        - Text and token diffs are printed if the token ids differ.

        Args:
            texts: A list of segments ``{"text": str, "compute_loss": bool}``, each with
                optional ``"example_ind"`` and ``"data_id"``. A plain string is also accepted.
            loss_on_start_token: Accepted for interface compatibility; ignored.
            loss_on_eos: Must be False; loss on EOS is not supported.
            include_eos: Whether the returned sequence ends with the EOS token.

        Returns:
            For a plain string: the tokenizer encoding, with EOS appended if ``include_eos``.
            For segments: the first segment's encoding, with these per-token lists replaced
            by their concatenation over all segments:
            - ``input_ids``
            - ``attention_mask``
            - ``labels`` (``-100`` where no loss)
            - ``example_inds`` (``-1`` when missing)
            - ``data_ids`` (``-1`` when missing)

        Raises:
            ValueError: If ``loss_on_eos`` is True, if ``texts`` has no segments, or if an
                EOS token occurs inside a segment (unless EOS decodes to ``"<|im_end|>"``).
        """
        if loss_on_eos:
            raise ValueError("Loss on EOS is not currently supported.")

        if isinstance(texts, str):
            processed = self(text=texts)
            if (include_eos and self.eos_token_id is not None
                    and processed["input_ids"][-1] != self.eos_token_id):
                processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
                processed["attention_mask"] = processed["attention_mask"] + [1]
            return processed

        input_ids: List[int] = []
        attention_mask: List[int] = []
        labels: List[int] = []
        example_inds: List[Any] = []
        data_ids: List[Any] = []
        text_parts: List[str] = []
        first_encoding = None

        for i, item in enumerate(texts):
            encoding = self(text=item["text"])
            if first_encoding is None:
                first_encoding = encoding
            seg_ids = list(encoding["input_ids"])
            seg_mask = list(encoding["attention_mask"])

            # Only the first segment keeps its BOS; later segments continue the same sequence.
            if i != 0 and seg_ids and seg_ids[0] == self.bos_token_id:
                seg_ids, seg_mask = seg_ids[1:], seg_mask[1:]

            # A trailing EOS is dropped here; one EOS is handled for the whole sequence below.
            if seg_ids and seg_ids[-1] == self.eos_token_id:
                seg_ids, seg_mask = seg_ids[:-1], seg_mask[:-1]

            if self.eos_token_id in seg_ids and self.decode([self.eos_token_id]) != "<|im_end|>":
                raise ValueError(f"EOS token is present in input_ids: {seg_ids}. Not currently supported.")

            # Also drop a second leading BOS (e.g. a literal "<bos>" at the start of the text).
            if i != 0 and seg_ids and seg_ids[0] == self.bos_token_id:
                seg_ids, seg_mask = seg_ids[1:], seg_mask[1:]

            seg_labels = list(seg_ids) if item["compute_loss"] else [-100] * len(seg_ids)

            example_ind = item.get("example_ind")
            data_id = item.get("data_id")

            input_ids.extend(seg_ids)
            attention_mask.extend(seg_mask)
            labels.extend(seg_labels)
            example_inds.extend([-1 if example_ind is None else example_ind] * len(seg_ids))
            data_ids.extend([-1 if data_id is None else data_id] * len(seg_ids))
            text_parts.append(item["text"])

        if first_encoding is None:
            raise ValueError("tokenize_loss_texts requires at least one text segment.")

        # Sanity check: segment-wise tokenization should equal tokenizing the whole text.
        processed_all = self(text="".join(text_parts))
        whole_ids = list(processed_all["input_ids"])
        if len(whole_ids) != len(input_ids):
            warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(input_ids)}")
        if whole_ids != input_ids:
            self._print_tokenization_diff(whole_ids, input_ids)

        if include_eos:
            if self.eos_token_id is not None and (not input_ids or input_ids[-1] != self.eos_token_id):
                input_ids.append(self.eos_token_id)
                attention_mask.append(1)
                # loss_on_eos is rejected above, so EOS never carries loss.
                labels.append(-100)
                example_inds.append(-1)
                data_ids.append(-1)
        elif input_ids and input_ids[-1] == self.eos_token_id:
            for sequence in (input_ids, attention_mask, labels, example_inds, data_ids):
                sequence.pop()

        processed = first_encoding.copy()
        processed["input_ids"] = input_ids
        processed["attention_mask"] = attention_mask
        processed["labels"] = labels
        processed["example_inds"] = example_inds
        processed["data_ids"] = data_ids
        return processed

    def tokenize_messages(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ) -> Dict[str, Any]:
        """
        Tokenize messages into training inputs with loss labels applied.

        Combines ``messages_to_loss_texts`` and ``tokenize_loss_texts``.

        Args:
            messages: Message dicts with roles description / input / output.
            loss_on_start_token: Forwarded to both steps (both currently ignore it).
            loss_on_eos: Forwarded to ``tokenize_loss_texts``. True raises ValueError there,
                because loss on EOS is not supported.
            include_eos: Whether the returned sequence ends with the EOS token.
        """
        texts = self.messages_to_loss_texts(messages, loss_on_start_token)
        return self.tokenize_loss_texts(
            texts,
            loss_on_start_token=loss_on_start_token,
            loss_on_eos=loss_on_eos,
            include_eos=include_eos,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register GemmaChatTokenizer as a fast-only tokenizer with AutoTokenizer, under the
# key "GemmaChatTokenizer".
# NOTE(review): transformers documents the first argument of AutoTokenizer.register as a
# PretrainedConfig subclass. A plain string key is accepted, but it is presumably never
# matched during config-based lookup. Loading via the "auto_map" entry written by
# save_pretrained likely does not depend on this call — confirm before relying on it.
AutoTokenizer.register(
    "GemmaChatTokenizer",
    slow_tokenizer_class=None,
    fast_tokenizer_class=GemmaChatTokenizer,
)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Load the stock Gemma 3 tokenizer and re-wrap it as a GemmaChatTokenizer.
    custom_tokenizer = GemmaChatTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")

    # Example conversations:
    # - description only
    # - description + multi-turn input/output
    # - description + outputs only
    # - outputs only
    # - multi-turn input/output without a description
    test_messages = [
        [
            {"role": "description", "content": "Pick a number between 1 and 100"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
            {"role": "output", "content": "10"},
            {"role": "input", "content": "What is 4+4?"},
        ],
    ]

    for messages in test_messages:
        texts = custom_tokenizer.messages_to_loss_texts(messages)
        print("Texts with loss flags:")
        for i, text in enumerate(texts):
            print(f" {i}: {text}")

        text = custom_tokenizer.messages_to_text(messages, start_generation=True)
        print("\nFull text with generation prompt:")
        print(text)

    # Only the save half of the cycle is exercised here; the tokenizer is not reloaded.
    print("\nTesting save/load cycle:")
    tokenizer_path = "repos/chat-gemma-tokenizer"
    custom_tokenizer.save_pretrained(tokenizer_path)
    print("Tokenizer saved successfully!")

    # Ship this module next to the saved files, under the module name that the auto_map
    # entry written by save_pretrained refers to ("gemma_chat_tokenizer").
    import shutil
    shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_chat_tokenizer.py"))
    print("gemma_chat_tokenizer.py saved successfully!")