# explicit-it-v1 / gemma_explicit_tokenizer.py
# Uploaded by tsor13 -- "Initial upload of fine-tuned Gemma + custom tokenizer" (commit 1868b83, verified)
"""
Custom Gemma Tokenizer for explicit Format
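This tokenizer implements the explicit format for message processing.
Format: each turn is rendered as <start_of_turn>{role}, a newline, the content and
<end_of_turn>, using the roles description/input/output. The custom chat template maps
the standard chat roles onto these labels (system -> description, user -> input,
assistant -> output) and replaces the model's built-in chat template; the tokenizer
also provides per-piece loss computation flags for training.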
To save:
uv run tokenizers/gemma_explicit_tokenizer.py
which will save the tokenizer to the repos/explicit-gemma-tokenizer directory.
mkdir repos/explicit12b
# copy model over
cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_explicit/checkpoint-8/* repos/explicit12b/
# copy tokenizer over
cp repos/explicit-gemma-tokenizer/* repos/explicit12b/
# upload to hf
uv run upload_to_hf.py \
--folder repos/explicit12b \
--repo-id tsor13/explicit12b
"""
from typing import List, Dict, Any, Optional, Union
from transformers import AutoTokenizer
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
import warnings
import difflib
import json
import os
import sys
# # Add parent directory to path to import chat_utils
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from chat_utils import chat_messages_to_text_loss, chat_messages_to_raw_text
# Jinja chat template used by GemmaExplicitTokenizer for apply_chat_template.
# Rendered layout (one turn per block):
#   {bos_token}<start_of_turn>description
#   {system text, trimmed}<end_of_turn>
#   <start_of_turn>input
#   {user text, trimmed}<end_of_turn>
#   <start_of_turn>output
#   {assistant text, trimmed}<end_of_turn>
#   ...
# - A leading "system" message supplies the description: its string content, or the
#   "text" of its first content part. Without one, the description defaults to
#   "You are a helpful assistant.".
# - The remaining messages must alternate user/assistant starting with user, otherwise
#   raise_exception is called. user -> "input", assistant -> "output"; any other role
#   name (only reachable at odd positions given the alternation check) is emitted as-is.
# - List content renders "text" parts (trimmed) and "<start_of_image>" for "image" parts;
#   content that is neither a string nor iterable raises "Invalid content type".
# - add_generation_prompt appends "<start_of_turn>output" plus a newline.
# - The r-string keeps the "\n" escapes verbatim so Jinja interprets them; .strip("\n")
#   drops the line breaks right after the opening and before the closing quotes.
# - The {%- / -%} markers strip the template's own line breaks, so the layout is
#   whitespace-sensitive. NOTE(review): the newline after the description turn's
#   <end_of_turn> comes from the raw line break following that expression rather than a
#   string literal -- re-render and compare output if this template is edited.
# - Unlike messages_to_loss_texts (roles description/input/output, description only when
#   given), this template takes system/user/assistant roles and always emits a
#   description turn.
CUSTOM_CHAT_TEMPLATE = r"""
{{ bos_token }}{{ '<start_of_turn>description\n' }}
{%- if messages and messages[0]['role'] == 'system' -%}
{%- if messages[0]['content'] is string -%}
{{ messages[0]['content'] | trim }}
{%- else -%}
{{ messages[0]['content'][0]['text'] | trim }}
{%- endif -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
You are a helpful assistant.
{%- set loop_messages = messages -%}
{%- endif -%}
{{ '<end_of_turn>' }}
{# ----- regular turns (input/output) ----- #}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "output" -%}
{%- elif (message['role'] == 'user') -%}
{%- set role = "input" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '\n' }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{{ '<end_of_turn>\n' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ '<start_of_turn>output\n' }}
{%- endif -%}
""".strip("\n")
# CUSTOM_CHAT_TEMPLATE = r"""
# {# ----- system/description turn ----- #}
# {%- set sys = "" -%}
# {%- set loop_messages = messages -%}
# {%- if messages and messages[0]['role'] == 'system' -%}
# {%- if messages[0]['content'] is string -%}
# {%- set sys = messages[0]['content'] -%}
# {%- elif messages[0]['content'] is iterable -%}
# {# concatenate all text parts #}
# {%- set sys = messages[0]['content'] | selectattr('type','equalto','text') | map(attribute='text') | join('') -%}
# {%- else -%}
# {{ raise_exception("Invalid system content type") }}
# {%- endif -%}
# {%- set loop_messages = messages[1:] -%}
# {%- else -%}
# {%- set sys = "You are a helpful assistant." -%}
# {%- endif -%}
# <start_of_turn>description
# {{ sys | trim }}
# <end_of_turn>
# {# ----- user/assistant turns ----- #}
# {%- for message in loop_messages -%}
# {%- if message['role'] == 'user' -%}
# <start_of_turn>input{{"\n"}}
# {%- if message['content'] is string -%}
# {{ message['content'] | trim }}
# {%- elif message['content'] is iterable -%}
# {%- for item in message['content'] -%}
# {%- if item['type'] == 'text' -%}
# {{ item['text'] | trim }}
# {%- elif item['type'] == 'image' -%}
# <start_of_image>
# {%- endif -%}
# {%- endfor -%}
# {%- else -%}
# {{ raise_exception("Invalid user content type") }}
# {%- endif -%}
# <end_of_turn>{{"\n"}}
# {%- elif message['role'] == 'assistant' -%}
# <start_of_turn>output{{"\n"}}
# {%- if message['content'] is string -%}
# {{ message['content'] | trim }}
# {%- elif message['content'] is iterable -%}
# {%- for item in message['content'] -%}
# {%- if item['type'] == 'text' -%}
# {{ item['text'] | trim }}
# {%- elif item['type'] == 'image' -%}
# <start_of_image>
# {%- endif -%}
# {%- endfor -%}
# {%- else -%}
# {{ raise_exception("Invalid assistant content type") }}
# {%- endif -%}
# <end_of_turn>{{"\n"}}
# {%- else -%}
# {# ignore other roles by default; or raise if you prefer strictness #}
# {# {{ raise_exception("Unsupported role: " ~ message['role']) }} #}
# {%- endif -%}
# {%- endfor -%}
# {%- if add_generation_prompt -%}
# <start_of_turn>output
# {%- endif -%}
# """.strip("\n")
class GemmaExplicitTokenizer(GemmaTokenizerFast):
    """
    Custom tokenizer for Gemma models that implements explicit format message processing.

    Explicit-format messages use the roles ``description`` / ``input`` / ``output`` and
    are rendered as ``<start_of_turn>{role}``, a newline, the content, ``<end_of_turn>``
    and a newline. For the standard ``apply_chat_template`` path every instance installs
    ``CUSTOM_CHAT_TEMPLATE`` (system -> description, user -> input,
    assistant -> output).

    Loss is computed only on output content plus its closing ``<end_of_turn>``; the first
    output of a conversation that has no preceding description is excluded.

    Entry points:
        tokenize_messages: messages -> input_ids / attention_mask / labels for training.
        tokenize_messages_for_generation: one or many conversations -> plain encodings.

    Attributes:
        start_string (str): Marker that opens a turn (``"<start_of_turn>"``).
        end_string (str): Marker that closes a turn (``"<end_of_turn>"``).
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the tokenizer and switch it to the explicit format.

        Accepts the same arguments as GemmaTokenizerFast. Any ``chat_template`` passed in
        (or loaded from a saved config) is replaced by ``CUSTOM_CHAT_TEMPLATE``.
        """
        super().__init__(*args, **kwargs)
        self._apply_explicit_format()

    def _apply_explicit_format(self) -> None:
        """Set the turn markers and install CUSTOM_CHAT_TEMPLATE on this instance."""
        self.start_string = "<start_of_turn>"
        self.end_string = "<end_of_turn>"
        # Record the custom attributes in init_kwargs so they travel with the saved
        # config. Work on a copy: after from_gemma_pretrained this dict is otherwise
        # shared with the temporary base tokenizer.
        init_kwargs = dict(getattr(self, "init_kwargs", None) or {})
        init_kwargs["start_string"] = self.start_string
        init_kwargs["end_string"] = self.end_string
        init_kwargs["chat_template"] = CUSTOM_CHAT_TEMPLATE
        self.init_kwargs = init_kwargs
        # CRITICAL: set the live attribute so apply_chat_template uses it now.
        self.chat_template = CUSTOM_CHAT_TEMPLATE

    @classmethod
    def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load a stock Gemma fast tokenizer and convert it into this class.

        ``GemmaTokenizerFast.from_pretrained`` does the loading; its instance attributes
        are copied onto a fresh ``cls`` instance (created without running ``__init__``),
        then the same explicit-format setup as ``__init__`` is applied, including the
        live ``chat_template``.

        Args:
            pretrained_model_name_or_path: Hub id or local path accepted by
                ``GemmaTokenizerFast.from_pretrained``.
            *args, **kwargs: Forwarded to ``GemmaTokenizerFast.from_pretrained``.

        Returns:
            A ``cls`` instance sharing the loaded tokenizer's state.
        """
        base_tokenizer = GemmaTokenizerFast.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )
        custom_tokenizer = cls.__new__(cls)
        for attr, value in base_tokenizer.__dict__.items():
            setattr(custom_tokenizer, attr, value)
        custom_tokenizer._apply_explicit_format()
        return custom_tokenizer

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save the tokenizer, then patch tokenizer_config.json with the custom class info.

        The patched config names ``GemmaExplicitTokenizer``, stores ``start_string`` /
        ``end_string`` and ``CUSTOM_CHAT_TEMPLATE``, and adds an ``auto_map`` entry for
        ``gemma_explicit_tokenizer.GemmaExplicitTokenizer``. That module file presumably
        has to sit next to the saved files for the entry to resolve (the ``__main__``
        demo copies it there).

        Returns:
            The value returned by ``GemmaTokenizerFast.save_pretrained``.
        """
        saved_files = super().save_pretrained(save_directory, **kwargs)

        config_file = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = {}

        config["tokenizer_class"] = "GemmaExplicitTokenizer"
        config["start_string"] = self.start_string
        config["end_string"] = self.end_string
        config["chat_template"] = CUSTOM_CHAT_TEMPLATE
        # Point to our custom class in the uploaded file.
        config["auto_map"] = {
            "AutoTokenizer": ["gemma_explicit_tokenizer.GemmaExplicitTokenizer", "gemma_explicit_tokenizer.GemmaExplicitTokenizer"]
        }
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)
        return saved_files

    def messages_to_loss_texts(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Split explicit-format messages into text pieces tagged with a loss flag.

        Description and input messages become one piece each,
        ``<start_of_turn>{role}\\n{content}<end_of_turn>\\n``, with no loss. An output
        message becomes three pieces:
        - the ``<start_of_turn>output\\n`` header (no loss);
        - ``{content}<end_of_turn>`` (loss, except for the first output when no
          description came before it);
        - a trailing newline (no loss).

        Each piece also carries the message's own keys (``role``, ``content`` and any
        extras such as ``example_ind`` / ``data_id``). Those keys override ``text`` /
        ``compute_loss`` if the message happens to define them.

        Args:
            messages: Dicts with ``role`` in {"description", "input", "output"} and
                ``content``.
            loss_on_start_token: Accepted for interface compatibility; currently forced
                to False.

        Returns:
            List of dicts with at least ``text`` and ``compute_loss``.

        Raises:
            ValueError: If a message has any other role.
        """
        # FOR NOW, OVERRIDING TO FALSE (loss on the turn header is not supported).
        loss_on_start_token = False

        pieces = []
        seen_description = False
        is_first_output = True
        for message in messages:
            role = message["role"]
            content = message["content"]
            if role in ("description", "input"):
                if role == "description":
                    seen_description = True
                turn = f"{self.start_string}{role}\n{content}{self.end_string}\n"
                pieces.append({"text": turn, "compute_loss": False, **message})
            elif role == "output":
                if loss_on_start_token:
                    raise ValueError("Loss on start token is not supported for chat formatters.")
                header = f"{self.start_string}{role}\n"
                pieces.append({"text": header, "compute_loss": False, **message})
                # Without a description the first output presumably can't be predicted
                # from context, so it is not trained on.
                trains_on_output = seen_description or not is_first_output
                body = f"{content}{self.end_string}"
                pieces.append({"text": body, "compute_loss": trains_on_output, **message})
                pieces.append({"text": "\n", "compute_loss": False, **message})
                is_first_output = False
            else:
                raise ValueError(f"Unknown role: {role}. Must be description, input, or output.")
        return pieces

    def messages_to_text(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
    ) -> str:
        """
        Render explicit-format messages as one string.

        Concatenates the pieces from ``messages_to_loss_texts``; when
        ``start_generation`` is True an ``<start_of_turn>output`` header and newline is
        appended so the model continues with an output turn.
        """
        text = "".join(piece["text"] for piece in self.messages_to_loss_texts(messages))
        if start_generation:
            text += self.start_string + "output\n"
        return text

    def tokenize_messages_for_generation(
        self,
        messages: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
        start_generation: bool = False,
        **kwargs,
    ):
        """
        Tokenize one conversation or a batch of conversations as plain text.

        This is the generation-oriented counterpart of ``tokenize_messages`` (which
        produces training labels). A list whose first element is itself a list is
        treated as a batch; anything else is a single conversation.

        Args:
            messages: One conversation, or a list of conversations.
            start_generation: Append an output header to each rendered conversation.
            **kwargs: Forwarded to the tokenizer call (e.g. padding, return_tensors).

        Returns:
            The tokenizer's batch encoding for the rendered texts.
        """
        if isinstance(messages, list) and messages and isinstance(messages[0], list):
            rendered = [self.messages_to_text(conversation, start_generation) for conversation in messages]
        else:
            rendered = [self.messages_to_text(messages, start_generation)]
        return self(text=rendered, **kwargs)

    def _print_tokenization_diff(self, expected_ids, actual_ids) -> None:
        """Print text-level and token-level unified diffs between two id sequences."""
        expected_text = self.decode(expected_ids, skip_special_tokens=False)
        actual_text = self.decode(actual_ids, skip_special_tokens=False)
        print("Diff between texts:")
        print("\n".join(difflib.unified_diff(expected_text.splitlines(), actual_text.splitlines())))
        print("Diff between tokenized texts:")
        print("\n".join(difflib.unified_diff([str(t) for t in expected_ids], [str(t) for t in actual_ids])))

    def tokenize_loss_texts(
        self,
        texts: Union[str, List[Dict[str, Any]]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ):
        """
        Tokenize loss-tagged text pieces into input_ids / attention_mask / labels.

        Each piece (as produced by ``messages_to_loss_texts``) is tokenized on its own so
        labels can be set per piece. ``compute_loss`` pieces use their token ids as
        labels; all others get -100. Only the first piece keeps its BOS, and per-piece
        trailing EOS tokens are dropped. A single EOS (label -100) is then appended to the
        whole sequence; ``include_eos=False`` removes it again. ``example_inds`` /
        ``data_ids`` hold one value per token, taken from each piece's ``example_ind`` /
        ``data_id`` (-1 when missing or None, and for the appended EOS).

        As a sanity check the concatenated text is re-tokenized in one pass. If the token
        count differs, a warning is issued and text/token diffs are printed.

        A plain string is tokenized directly (no labels or index lists). It ends with EOS
        exactly when ``include_eos`` is True.

        Args:
            texts: List of ``{"text", "compute_loss", ...}`` dicts, or a single string.
            loss_on_start_token: Currently unused.
            loss_on_eos: Must be False; loss on EOS is not supported.
            include_eos: Whether the returned sequence ends with the EOS token.

        Returns:
            The first piece's encoding with its token fields replaced by the concatenated
            sequences, plus ``labels``, ``example_inds`` and ``data_ids``.

        Raises:
            ValueError: If ``loss_on_eos`` is True, ``texts`` is an empty list, or an EOS
                token appears inside a piece (unless EOS decodes to ``<|im_end|>``).
        """
        if loss_on_eos:
            raise ValueError("Loss on EOS is not currently supported.")

        eos_id = self.eos_token_id
        bos_id = self.bos_token_id

        # Single string: no labels, just honour include_eos.
        if isinstance(texts, str):
            processed = self(text=texts)
            ends_with_eos = eos_id is not None and processed["input_ids"][-1:] == [eos_id]
            if include_eos and eos_id is not None and not ends_with_eos:
                processed["input_ids"] = processed["input_ids"] + [eos_id]
                processed["attention_mask"] = processed["attention_mask"] + [1]
            elif not include_eos and ends_with_eos:
                processed["input_ids"] = processed["input_ids"][:-1]
                processed["attention_mask"] = processed["attention_mask"][:-1]
            return processed

        if not texts:
            raise ValueError("texts must contain at least one item.")

        pieces = []
        text_parts = []
        example_inds = []
        dataset_inds = []
        for i, item in enumerate(texts):
            piece = self(text=item["text"])
            # Continuation pieces must not start with their own BOS.
            if i != 0 and piece["input_ids"] and piece["input_ids"][0] == bos_id:
                piece["input_ids"] = piece["input_ids"][1:]
                piece["attention_mask"] = piece["attention_mask"][1:]
            # One EOS is appended to the whole sequence below, so drop per-piece EOS.
            if piece["input_ids"] and piece["input_ids"][-1] == eos_id:
                piece["input_ids"] = piece["input_ids"][:-1]
                piece["attention_mask"] = piece["attention_mask"][:-1]
            # An EOS inside a piece is rejected, except when EOS decodes to <|im_end|>.
            if eos_id in piece["input_ids"] and self.decode([eos_id]) != "<|im_end|>":
                raise ValueError(f"EOS token is present in input_ids: {piece['input_ids']}. Not currently supported.")

            if item["compute_loss"]:
                piece["labels"] = list(piece["input_ids"])
            else:
                piece["labels"] = [-100] * len(piece["input_ids"])

            # Drop a second leading BOS on continuation pieces (presumably from piece
            # text that itself starts with the BOS token).
            if pieces and piece["input_ids"] and piece["input_ids"][0] == bos_id:
                for key in ("input_ids", "attention_mask", "labels"):
                    piece[key] = piece[key][1:]

            pieces.append(piece)
            text_parts.append(item["text"])

            n_tokens = len(piece["input_ids"])
            example_ind = item.get("example_ind")
            example_inds.extend([-1 if example_ind is None else example_ind] * n_tokens)
            data_id = item.get("data_id")
            dataset_inds.extend([-1 if data_id is None else data_id] * n_tokens)

        processed = pieces[0].copy()
        for key in ("input_ids", "attention_mask", "labels"):
            processed[key] = [token for piece in pieces for token in piece[key]]
        processed["example_inds"] = example_inds
        processed["data_ids"] = dataset_inds

        # Sanity check: piecewise tokenization should match tokenizing everything at once.
        reference = self(text="".join(text_parts))
        if len(reference["input_ids"]) != len(processed["input_ids"]):
            warnings.warn(
                "Tokenizing the concatenated text gives a different number of tokens than "
                "tokenizing the pieces separately. Please check your dataset. "
                f"{len(reference['input_ids'])} != {len(processed['input_ids'])}"
            )
            self._print_tokenization_diff(reference["input_ids"], processed["input_ids"])

        if eos_id is not None and processed["input_ids"][-1:] != [eos_id]:
            processed["input_ids"] = processed["input_ids"] + [eos_id]
            processed["example_inds"] = processed["example_inds"] + [-1]
            processed["attention_mask"] = processed["attention_mask"] + [1]
            # loss_on_eos is rejected above, so the appended EOS is never a target.
            processed["labels"] = processed["labels"] + [-100]
            processed["data_ids"] = processed["data_ids"] + [-1]

        if not include_eos and processed["input_ids"] and processed["input_ids"][-1] == eos_id:
            for key in ("input_ids", "attention_mask", "labels", "example_inds", "data_ids"):
                processed[key] = processed[key][:-1]

        return processed

    def tokenize_messages(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ) -> Dict[str, Any]:
        """
        Tokenize explicit-format messages into training inputs with loss labels.

        Equivalent to ``tokenize_loss_texts(messages_to_loss_texts(messages), ...)``.

        Args:
            messages: Dicts with ``role`` in {"description", "input", "output"} and
                ``content``.
            loss_on_start_token: Forwarded; currently ignored downstream.
            loss_on_eos: Must be False (``tokenize_loss_texts`` raises otherwise).
            include_eos: Whether the sequence ends with the EOS token.

        Returns:
            Encoding with input_ids, attention_mask, labels, example_inds and data_ids.
        """
        texts = self.messages_to_loss_texts(messages, loss_on_start_token)
        # Pass by keyword: tokenize_loss_texts' second positional parameter is
        # loss_on_start_token, not loss_on_eos.
        return self.tokenize_loss_texts(
            texts,
            loss_on_start_token=loss_on_start_token,
            loss_on_eos=loss_on_eos,
            include_eos=include_eos,
        )
# Register the custom fast tokenizer (no slow counterpart) with AutoTokenizer under the
# key "GemmaExplicitTokenizer".
# NOTE(review): AutoTokenizer.register is normally given a config class as its first
# argument; with a plain string key it is unclear whether this affects AutoTokenizer
# resolution -- confirm against the transformers version in use.
AutoTokenizer.register(
    "GemmaExplicitTokenizer",
    slow_tokenizer_class=None,
    fast_tokenizer_class=GemmaExplicitTokenizer,
)
def main() -> None:
    """
    Smoke-test the explicit format, then save the tokenizer for upload.

    Steps:
    1. Load google/gemma-3-1b-it into GemmaExplicitTokenizer.
    2. Print the explicit-format pieces and text for several conversations.
    3. Print the custom chat template's rendering of user/assistant chats.
    4. Save the tokenizer to repos/explicit-gemma-tokenizer, together with a copy of
       this file (the module referenced by the auto_map that save_pretrained writes).
    """
    import shutil

    # First load: convert the stock Gemma tokenizer into the custom class.
    custom_tokenizer = GemmaExplicitTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")
    # Set explicitly so apply_chat_template below uses the custom template.
    custom_tokenizer.chat_template = CUSTOM_CHAT_TEMPLATE

    # Conversations in the explicit description/input/output role format.
    explicit_examples = [
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
            {"role": "output", "content": "10"},
            {"role": "input", "content": "What is 4+4?"},
        ],
    ]
    for messages in explicit_examples:
        texts = custom_tokenizer.messages_to_loss_texts(messages)
        print("Texts with loss flags:")
        for i, text in enumerate(texts):
            print(f" {i}: {text}")
        text = custom_tokenizer.messages_to_text(messages, start_generation=True)
        print("\nFull text with generation prompt:")
        print(text)

    # Conversations in the standard chat format, rendered through the chat template
    # (batched input, so the first rendering is taken).
    chat_cases = [
        (
            [[
                {"role": "user", "content": "What is 2+2?"},
                {"role": "assistant", "content": "4"},
            ]],
            False,
        ),
        (
            [[
                {"role": "user", "content": "What is 2+2?"},
                {"role": "assistant", "content": "4"},
                {"role": "user", "content": "What is 4+2?"},
            ]],
            True,
        ),
    ]
    for conversations, add_generation_prompt in chat_cases:
        chat_text = custom_tokenizer.apply_chat_template(
            conversations, tokenize=False, add_generation_prompt=add_generation_prompt
        )[0]
        print("\nChat text:")
        print(chat_text)

    print("\nSaving tokenizer:")
    tokenizer_path = "repos/explicit-gemma-tokenizer"
    custom_tokenizer.save_pretrained(tokenizer_path)
    print("Tokenizer saved successfully!")
    # The saved auto_map points at gemma_explicit_tokenizer.py, so ship this file too.
    shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_explicit_tokenizer.py"))
    print("gemma_explicit_tokenizer.py saved successfully!")


if __name__ == "__main__":
    main()