|
|
""" |
|
|
Custom Gemma Tokenizer for chat Format |
|
|
|
|
|
This tokenizer implements the chat format for message processing: |
|
|
Format: Uses the standard chat template with proper role labels (user/assistant) |
|
|
|
|
|
The chat format uses the model's built-in chat template and flags which text
spans contribute to the training loss: only the assistant turns (rendered as
"model" turns by the Gemma chat template) are marked for loss computation.
|
|
|
|
|
To save: |
|
|
uv run tokenizers/gemma_chat_tokenizer.py |
|
|
which will save the tokenizer to the repos/chat-gemma-tokenizer directory. |
|
|
mkdir repos/chat12b |
|
|
# copy model over |
|
|
cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_chat/checkpoint-8/* repos/chat12b/ |
|
|
# copy tokenizer over |
|
|
cp repos/chat-gemma-tokenizer/* repos/chat12b/ |
|
|
# upload to hf |
|
|
|
|
|
uv run upload_to_hf.py \ |
|
|
--folder repos/chat12b \ |
|
|
--repo-id tsor13/chat12b |
|
|
""" |
|
|
|
|
|
from typing import List, Dict, Any, Optional, Union |
|
|
from transformers import AutoTokenizer |
|
|
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast |
|
|
from transformers.models.gemma.tokenization_gemma import GemmaTokenizer |
|
|
import warnings |
|
|
import difflib |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
|
|
|
|
|
|
class GemmaChatTokenizer(GemmaTokenizerFast):
    """
    Gemma fast tokenizer that renders description/input/output messages in chat format.

    Task messages (roles ``description``, ``input``, ``output``) are converted to
    chat messages (roles ``system``, ``user``, ``assistant``) and rendered with the
    tokenizer's built-in chat template. When producing training data, only the
    model (assistant) turns are flagged for loss computation; the system prompt,
    user turns and the model-turn headers are masked out.

    Attributes:
        start_string (str): Gemma's turn-opening marker, ``"<start_of_turn>"``.
        end_string (str): Gemma's turn-closing marker, ``"<end_of_turn>"``. A model
            turn's loss span runs up to and including this marker.
    """

    # Base system prompt. messages_to_chat_messages appends the description and
    # any example input/output that precede the first description/output to it.
    _CHAT_PROMPT = """You are tasked with generating outputs from a particular, potentially stochastic, generative process. You will be given some combination of:
- Description: A natural description of the generative process / data distribution
- Input: An input on which to condition the generative process.
- Example outputs: Example outputs from the process, either in a user message or as prior generations from a chat message. You may assume that any given outputs are exchangeable with one another (order-invariant) and generated from the same process (roughly i.i.d.). If the output data pertains to a single object, it just contains the output. If it contains multiple objects, use json formatting with keys for the name of the output variable.
You will be provided at least either a description or an example output.

Given these components, your job is to generate JUST the output in your response, roughly approximating the underlying generative process, maintaining any underlying stochasticity (if any is present). If you are asked to generate again, you will either be given an additional input to condition on, or will just be told to "Generate"."""

    def __init__(self, *args, **kwargs):
        """
        Initialize the tokenizer.

        Accepts the same arguments as GemmaTokenizerFast, then records the Gemma
        turn markers on the instance and in ``init_kwargs``.
        """
        super().__init__(*args, **kwargs)
        self._configure_chat_strings()

    def _configure_chat_strings(self) -> None:
        """Set the Gemma turn markers and mirror them into ``init_kwargs``."""
        self.start_string = "<start_of_turn>"
        self.end_string = "<end_of_turn>"
        if not hasattr(self, "init_kwargs"):
            self.init_kwargs = {}
        self.init_kwargs["start_string"] = self.start_string
        self.init_kwargs["end_string"] = self.end_string

    @classmethod
    def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load a stock Gemma fast tokenizer and re-wrap it as this class.

        The base tokenizer is loaded with ``GemmaTokenizerFast.from_pretrained``.
        Its instance attributes are copied onto a new ``cls`` instance created
        without running ``__init__``, so the loaded vocabulary and settings are
        reused as-is.

        Args:
            pretrained_model_name_or_path: Hub id or local path accepted by
                ``GemmaTokenizerFast.from_pretrained``.
            *args, **kwargs: Forwarded to ``GemmaTokenizerFast.from_pretrained``.

        Returns:
            A ``cls`` instance carrying the base tokenizer's state plus the turn markers.
        """
        base_tokenizer = GemmaTokenizerFast.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        custom_tokenizer = cls.__new__(cls)
        # Shallow attribute copy: objects such as the backend tokenizer are shared
        # with base_tokenizer rather than duplicated.
        for attr, value in base_tokenizer.__dict__.items():
            setattr(custom_tokenizer, attr, value)

        custom_tokenizer._configure_chat_strings()
        return custom_tokenizer

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save the tokenizer, then patch ``tokenizer_config.json`` for this class.

        After the standard save, the config's ``tokenizer_class``, ``start_string``,
        ``end_string`` and ``auto_map`` entries are set so the tokenizer can be
        reloaded from ``gemma_chat_tokenizer.GemmaChatTokenizer``.

        Args:
            save_directory: Target directory.
            **kwargs: Forwarded to the parent ``save_pretrained``.

        Returns:
            Whatever the parent ``save_pretrained`` returned (the saved file paths).
        """
        saved_files = super().save_pretrained(save_directory, **kwargs)

        # The parent refuses to save into an existing file path; nothing to patch then.
        if os.path.isfile(save_directory):
            return saved_files

        config_file = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}

        config["tokenizer_class"] = "GemmaChatTokenizer"
        config["start_string"] = self.start_string
        config["end_string"] = self.end_string
        # Both slots point at the fast class; this module defines no slow variant.
        config["auto_map"] = {
            "AutoTokenizer": ["gemma_chat_tokenizer.GemmaChatTokenizer", "gemma_chat_tokenizer.GemmaChatTokenizer"]
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        return saved_files

    def messages_to_chat_messages(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
        default_user_message: str = "Generate.",
    ) -> List[Dict[str, Any]]:
        """
        Convert task messages (description / input / output) to chat messages (role / content).

        Intended to mirror the conversion in chat_utils.py. Rules:

        - ``description``: appended to the system prompt, and the system message is emitted.
        - ``input`` before any description/output: folded into the system prompt
          as "Example Input". After that: emitted as a ``user`` message.
        - ``output`` before any description/output: folded into the system prompt
          as "Example Output", and the system message is emitted. After that: emitted
          as an ``assistant`` message. It is preceded by ``default_user_message``
          unless the previous chat message is already a user turn, so user and
          assistant turns always alternate.
        - If nothing was emitted, a system message and an empty user message are
          returned. If only the system message was emitted, an empty user message follows it.
        - With ``start_generation``, a trailing assistant turn is followed by
          ``default_user_message`` so the model is prompted for a new output.

        Returns:
            List of ``{"role": ..., "content": ...}`` dicts.
        """
        chat_messages: List[Dict[str, Any]] = []
        system_message = self._CHAT_PROMPT
        has_description_or_output = False

        for message in messages:
            role = message["role"]
            content = message["content"]
            if role == "description":
                system_message += "\n\nDescription: " + content
                chat_messages.append({"role": "system", "content": system_message})
                has_description_or_output = True
            elif role == "input":
                if not has_description_or_output:
                    system_message += "\n\nExample Input: " + content
                else:
                    chat_messages.append({"role": "user", "content": content})
            elif role == "output":
                if not has_description_or_output:
                    system_message += "\n\nExample Output: " + content
                    chat_messages.append({"role": "system", "content": system_message})
                    has_description_or_output = True
                else:
                    # Chat templates require user/assistant alternation, so insert a
                    # default user turn unless the previous turn is already a user turn.
                    if not chat_messages or chat_messages[-1]["role"] != "user":
                        chat_messages.append({"role": "user", "content": default_user_message})
                    chat_messages.append({"role": "assistant", "content": content})

        if not chat_messages:
            chat_messages.append({"role": "system", "content": system_message})
            chat_messages.append({"role": "user", "content": ""})
        if len(chat_messages) == 1:
            chat_messages.append({"role": "user", "content": ""})

        if start_generation and chat_messages[-1]["role"] == "assistant":
            chat_messages.append({"role": "user", "content": default_user_message})

        return chat_messages

    def messages_to_loss_texts(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        default_user_message: str = "Generate.",
        start_generation: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Convert task messages into text segments flagged for loss computation.

        The messages are rendered with the chat template and the ``<bos>`` text is
        removed. The result is then split so that each model turn's content,
        through its ``end_string``, forms a segment with ``compute_loss=True``.
        Everything else (system/user text and the model-turn headers) forms
        ``compute_loss=False`` segments. Concatenating all segment texts
        reproduces the rendered text.

        Args:
            messages: Task messages with ``role`` in {description, input, output}.
            loss_on_start_token: Ignored; the turn header is never part of the loss.
            default_user_message: User text inserted before outputs that have no input.
            start_generation: If True, render with a trailing generation prompt.

        Returns:
            An ordered list of ``{"text": str, "compute_loss": bool}`` dicts.

        Raises:
            ValueError: If the chat template renders to an empty string.
        """
        chat_messages = self.messages_to_chat_messages(
            messages, start_generation=start_generation, default_user_message=default_user_message
        )

        full_text = self.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=start_generation)
        # Drop the template's BOS text; tokenize_loss_texts keeps exactly one BOS
        # (from the first segment's tokenization).
        full_text = full_text.replace("<bos>", "")

        model_start_text = self.start_string + "model\n"
        texts: List[Dict[str, Any]] = []
        remaining = full_text

        while model_start_text in remaining:
            # Everything up to and including the model-turn header is context.
            header_end = remaining.find(model_start_text) + len(model_start_text)
            texts.append({"text": remaining[:header_end], "compute_loss": False})
            remaining = remaining[header_end:]

            # The model turn's content through its end marker is the loss span.
            end_loc = remaining.find(self.end_string)
            if end_loc == -1:
                # Unterminated turn (e.g. the empty generation prompt): take the rest.
                turn_end = len(remaining)
            else:
                turn_end = end_loc + len(self.end_string)
            texts.append({"text": remaining[:turn_end], "compute_loss": True})
            remaining = remaining[turn_end:]

        if remaining:
            texts.append({"text": remaining, "compute_loss": False})

        if not texts:
            raise ValueError("Chat template produced no text for the given messages.")

        return texts

    def messages_to_text(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
    ) -> str:
        """
        Render task messages (description / input / output) as raw chat-format text.

        This is the concatenation of the segments from ``messages_to_loss_texts``.
        """
        segments = self.messages_to_loss_texts(messages, start_generation=start_generation)
        return "".join(segment["text"] for segment in segments)

    def tokenize_messages_for_generation(
        self,
        messages: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
        start_generation: bool = False,
        **kwargs,
    ):
        """
        Tokenize one conversation or a batch of conversations as plain text (no labels).

        Args:
            messages: A single list of task messages, or a list of such lists.
            start_generation: If True, append the generation prompt to each conversation.
            **kwargs: Forwarded to the tokenizer call (e.g. padding, return_tensors).

        Returns:
            The tokenizer's encoding of the batch of rendered texts. A single
            conversation is treated as a batch of one.
        """
        if messages and isinstance(messages[0], list):
            conversations = messages
        else:
            conversations = [messages]
        all_texts = [self.messages_to_text(conversation, start_generation) for conversation in conversations]
        return self(text=all_texts, **kwargs)

    def tokenize_loss_texts(
        self,
        texts: Union[str, List[Dict[str, Any]]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ):
        """
        Tokenize loss-flagged segments into input_ids / attention_mask / labels.

        Each segment (as produced by ``messages_to_loss_texts``) is tokenized on
        its own so labels can be assigned per segment. Segments with
        ``compute_loss`` keep their token ids as labels; all others get -100.
        Only the first segment keeps its BOS token. The per-segment results are
        concatenated, checked against a tokenization of the concatenated text,
        and terminated with an EOS token.

        Args:
            texts: Either a plain string (tokenized directly with EOS appended;
                no labels), or a list of dicts with keys ``text`` and
                ``compute_loss`` and optional ``example_ind`` / ``data_id``.
            loss_on_start_token: Unused; kept for signature compatibility.
            loss_on_eos: Must be False; loss on EOS is not supported.
            include_eos: If False, the trailing EOS token is removed from every
                returned list.

        Returns:
            A copy of the first segment's encoding with ``input_ids``,
            ``attention_mask``, ``labels``, ``example_inds`` and ``data_ids``
            replaced by concatenated per-token lists. Missing/None ``example_ind``
            and ``data_id`` values, and the appended EOS position, are -1.

        Raises:
            ValueError: If ``loss_on_eos`` is True, if ``texts`` is empty or every
                segment tokenizes to nothing, or if an EOS token appears inside a
                segment (unless the EOS token decodes to ``"<|im_end|>"``).
        """
        if loss_on_eos:
            raise ValueError("Loss on EOS is not currently supported.")

        eos_id = self.eos_token_id
        bos_id = self.bos_token_id

        if isinstance(texts, str):
            processed = self(text=texts)
            ids = processed["input_ids"]
            if eos_id is not None and (not ids or ids[-1] != eos_id):
                processed["input_ids"] = ids + [eos_id]
                processed["attention_mask"] = processed["attention_mask"] + [1]
            return processed

        if not texts:
            raise ValueError("texts must contain at least one segment.")

        segments = []
        text_parts: List[str] = []
        example_inds: List[int] = []
        dataset_inds: List[int] = []

        for item in texts:
            processed = self(text=item["text"])
            ids = processed["input_ids"]
            mask = processed["attention_mask"]

            # Only the first emitted segment keeps the BOS the tokenizer prepends.
            if segments and ids and ids[0] == bos_id:
                ids, mask = ids[1:], mask[1:]
            # Per-segment EOS is dropped; a single EOS is appended after concatenation.
            if ids and ids[-1] == eos_id:
                ids, mask = ids[:-1], mask[:-1]
            if not ids:
                # Empty segment (e.g. an open generation-prompt turn) contributes no tokens.
                continue

            if eos_id in ids and self.decode([eos_id]) != "<|im_end|>":
                raise ValueError(f"EOS token is present in input_ids: {ids}. Not currently supported.")

            processed["input_ids"] = ids
            processed["attention_mask"] = mask
            processed["labels"] = list(ids) if item["compute_loss"] else [-100] * len(ids)
            segments.append(processed)
            text_parts.append(item["text"])

            example_ind = item.get("example_ind")
            example_inds.extend([-1 if example_ind is None else example_ind] * len(ids))
            data_id = item.get("data_id")
            dataset_inds.extend([-1 if data_id is None else data_id] * len(ids))

        if not segments:
            raise ValueError("All text segments tokenized to zero tokens.")

        merged = segments[0].copy()
        merged["input_ids"] = [tok for seg in segments for tok in seg["input_ids"]]
        merged["attention_mask"] = [m for seg in segments for m in seg["attention_mask"]]
        merged["labels"] = [lab for seg in segments for lab in seg["labels"]]
        merged["example_inds"] = example_inds
        merged["data_ids"] = dataset_inds

        # Sanity check: segment-wise tokenization should equal tokenizing the whole
        # text at once; token merges across segment boundaries would break this.
        reference_ids = self(text="".join(text_parts))["input_ids"]
        if len(reference_ids) != len(merged["input_ids"]):
            warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(reference_ids)} != {len(merged['input_ids'])}")
        if reference_ids != merged["input_ids"]:
            self._print_tokenization_diff(reference_ids, merged["input_ids"])

        if eos_id is not None and merged["input_ids"][-1] != eos_id:
            merged["input_ids"] = merged["input_ids"] + [eos_id]
            merged["attention_mask"] = merged["attention_mask"] + [1]
            merged["labels"] = merged["labels"] + [eos_id if loss_on_eos else -100]
            merged["example_inds"] = merged["example_inds"] + [-1]
            merged["data_ids"] = merged["data_ids"] + [-1]

        if not include_eos and merged["input_ids"][-1] == eos_id:
            for key in ("input_ids", "attention_mask", "labels", "example_inds", "data_ids"):
                merged[key] = merged[key][:-1]

        return merged

    def _print_tokenization_diff(self, reference_ids: List[int], merged_ids: List[int]) -> None:
        """Print text-level and token-level unified diffs between two token-id lists."""
        reference_text = self.decode(reference_ids, skip_special_tokens=False)
        merged_text = self.decode(merged_ids, skip_special_tokens=False)
        text_diff = difflib.unified_diff(reference_text.splitlines(), merged_text.splitlines())
        print("Diff between texts:")
        print("\n".join(text_diff))

        token_diff = difflib.unified_diff([str(t) for t in reference_ids], [str(t) for t in merged_ids])
        print("Diff between tokenized texts:")
        print("\n".join(token_diff))

    def tokenize_messages(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ) -> Dict[str, Any]:
        """
        Tokenize task messages into training inputs with loss labels applied.

        Equivalent to ``tokenize_loss_texts(messages_to_loss_texts(messages))``.
        The flags are forwarded by keyword so each reaches its intended parameter.
        """
        texts = self.messages_to_loss_texts(messages, loss_on_start_token=loss_on_start_token)
        return self.tokenize_loss_texts(
            texts,
            loss_on_start_token=loss_on_start_token,
            loss_on_eos=loss_on_eos,
            include_eos=include_eos,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Make GemmaChatTokenizer known to AutoTokenizer as a fast-only tokenizer.
# NOTE(review): AutoTokenizer.register documents its first argument as a config
# class; passing the string "GemmaChatTokenizer" registers an entry keyed by that
# string, which config-based lookup presumably never hits — confirm intent.
AutoTokenizer.register(
    "GemmaChatTokenizer",
    slow_tokenizer_class=None,
    fast_tokenizer_class=GemmaChatTokenizer,
)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import shutil

    # Wrap the instruction-tuned Gemma 3 1B tokenizer (fetched via from_pretrained).
    tok = GemmaChatTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")

    # Smoke-test conversations: description only, description + inputs/outputs,
    # outputs only, and inputs/outputs without a description.
    test_cases = [
        [
            {"role": "description", "content": "Pick a number between 1 and 100"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
            {"role": "output", "content": "10"},
            {"role": "input", "content": "What is 4+4?"},
        ],
        [
            {"role": "description", "content": "DESCRIPTION"},
            {"role": "input", "content": "INPUT1"},
            {"role": "output", "content": "OUTPUT1"},
            {"role": "input", "content": "INPUT2"},
            {"role": "output", "content": "OUTPUT2"},
        ],
        [
            {"role": "description", "content": "DESCRIPTION"},
            {"role": "output", "content": "OUTPUT1"},
            {"role": "output", "content": "OUTPUT2"},
        ],
    ]

    for case in test_cases:
        segments = tok.messages_to_loss_texts(case)
        print("Texts with loss flags:")
        for idx, segment in enumerate(segments):
            print(f" {idx}: {segment}")

        rendered = tok.messages_to_text(case, start_generation=True)
        print("\nFull text with generation prompt:")
        print(rendered)

    print("\nTesting save/load cycle:")
    save_dir = "repos/chat-gemma-tokenizer"
    tok.save_pretrained(save_dir)
    print("Tokenizer saved successfully!")

    # Ship this source file next to the saved config; its auto_map points at
    # gemma_chat_tokenizer.GemmaChatTokenizer.
    shutil.copy(__file__, os.path.join(save_dir, "gemma_chat_tokenizer.py"))
    print("GemmaChatTokenizer.py saved successfully!")