Initial upload of fine‑tuned Gemma + custom tokenizer
Browse files- .gitattributes +1 -0
- added_tokens.json +3 -0
- config.json +62 -0
- gemma_chat_tokenizer.py +508 -0
- generation_config.json +13 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- special_tokens_map.json +33 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
- training_args.bin +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<image_soft_token>": 262144
|
| 3 |
+
}
|
config.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Gemma3ForConditionalGeneration"
|
| 4 |
+
],
|
| 5 |
+
"boi_token_index": 255999,
|
| 6 |
+
"eoi_token_index": 256000,
|
| 7 |
+
"eos_token_id": [
|
| 8 |
+
1,
|
| 9 |
+
106
|
| 10 |
+
],
|
| 11 |
+
"image_token_index": 262144,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"mm_tokens_per_image": 256,
|
| 14 |
+
"model_type": "gemma3",
|
| 15 |
+
"text_config": {
|
| 16 |
+
"attention_bias": false,
|
| 17 |
+
"attention_dropout": 0.0,
|
| 18 |
+
"attn_logit_softcapping": null,
|
| 19 |
+
"cache_implementation": "hybrid",
|
| 20 |
+
"final_logit_softcapping": null,
|
| 21 |
+
"head_dim": 256,
|
| 22 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
| 23 |
+
"hidden_size": 3840,
|
| 24 |
+
"initializer_range": 0.02,
|
| 25 |
+
"intermediate_size": 15360,
|
| 26 |
+
"max_position_embeddings": 131072,
|
| 27 |
+
"model_type": "gemma3_text",
|
| 28 |
+
"num_attention_heads": 16,
|
| 29 |
+
"num_hidden_layers": 48,
|
| 30 |
+
"num_key_value_heads": 8,
|
| 31 |
+
"query_pre_attn_scalar": 256,
|
| 32 |
+
"rms_norm_eps": 1e-06,
|
| 33 |
+
"rope_local_base_freq": 10000.0,
|
| 34 |
+
"rope_scaling": {
|
| 35 |
+
"factor": 8.0,
|
| 36 |
+
"rope_type": "linear"
|
| 37 |
+
},
|
| 38 |
+
"rope_theta": 1000000.0,
|
| 39 |
+
"sliding_window": 1024,
|
| 40 |
+
"sliding_window_pattern": 6,
|
| 41 |
+
"torch_dtype": "float32",
|
| 42 |
+
"use_cache": true,
|
| 43 |
+
"vocab_size": 262208
|
| 44 |
+
},
|
| 45 |
+
"torch_dtype": "bfloat16",
|
| 46 |
+
"transformers_version": "4.51.3",
|
| 47 |
+
"vision_config": {
|
| 48 |
+
"attention_dropout": 0.0,
|
| 49 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 50 |
+
"hidden_size": 1152,
|
| 51 |
+
"image_size": 896,
|
| 52 |
+
"intermediate_size": 4304,
|
| 53 |
+
"layer_norm_eps": 1e-06,
|
| 54 |
+
"model_type": "siglip_vision_model",
|
| 55 |
+
"num_attention_heads": 16,
|
| 56 |
+
"num_channels": 3,
|
| 57 |
+
"num_hidden_layers": 27,
|
| 58 |
+
"patch_size": 14,
|
| 59 |
+
"torch_dtype": "float32",
|
| 60 |
+
"vision_use_head": false
|
| 61 |
+
}
|
| 62 |
+
}
|
gemma_chat_tokenizer.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom Gemma Tokenizer for chat Format
|
| 3 |
+
|
| 4 |
+
This tokenizer implements the chat format for message processing:
|
| 5 |
+
Format: Uses the standard chat template with proper role labels (user/assistant)
|
| 6 |
+
|
| 7 |
+
The chat format uses the model's built-in chat template and includes proper
|
| 8 |
+
loss computation flags for training with "assistant" as the generation role.
|
| 9 |
+
|
| 10 |
+
To save:
|
| 11 |
+
uv run tokenizers/gemma_chat_tokenizer.py
|
| 12 |
+
which will save the tokenizer to the repos/chat-gemma-tokenizer directory.
|
| 13 |
+
mkdir repos/chat12b
|
| 14 |
+
# copy model over
|
| 15 |
+
cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_chat/checkpoint-8/* repos/chat12b/
|
| 16 |
+
# copy tokenizer over
|
| 17 |
+
cp repos/chat-gemma-tokenizer/* repos/chat12b/
|
| 18 |
+
# upload to hf
|
| 19 |
+
|
| 20 |
+
uv run upload_to_hf.py \
|
| 21 |
+
--folder repos/chat12b \
|
| 22 |
+
--repo-id tsor13/chat12b
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from typing import List, Dict, Any, Optional, Union
|
| 26 |
+
from transformers import AutoTokenizer
|
| 27 |
+
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
|
| 28 |
+
from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
|
| 29 |
+
import warnings
|
| 30 |
+
import difflib
|
| 31 |
+
import json
|
| 32 |
+
import os
|
| 33 |
+
import sys
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class GemmaChatTokenizer(GemmaTokenizerFast):
|
| 37 |
+
"""
|
| 38 |
+
Custom tokenizer for Gemma models that implements chat format message processing.
|
| 39 |
+
|
| 40 |
+
This tokenizer formats messages using the chat format where:
|
| 41 |
+
- Messages use the standard chat template with proper role labels
|
| 42 |
+
- Uses the model's built-in chat formatting
|
| 43 |
+
- Loss is computed on the assistant sections (not output)
|
| 44 |
+
|
| 45 |
+
Attributes:
|
| 46 |
+
start_string (str): The starting string used for output generation (depends on tokenizer)
|
| 47 |
+
end_string (str): The ending string used for output generation (depends on tokenizer)
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, *args, **kwargs):
|
| 51 |
+
"""
|
| 52 |
+
Initialize the custom tokenizer.
|
| 53 |
+
|
| 54 |
+
Accepts the same arguments as GemmaTokenizerFast.
|
| 55 |
+
"""
|
| 56 |
+
super().__init__(*args, **kwargs)
|
| 57 |
+
|
| 58 |
+
# For chat format, we use the tokenizer's own chat template
|
| 59 |
+
# The start/end strings will be determined by the chat template
|
| 60 |
+
self.start_string = "<start_of_turn>" # Will be set dynamically
|
| 61 |
+
self.end_string = "<end_of_turn>" # Will be set dynamically
|
| 62 |
+
|
| 63 |
+
# Add custom attributes to the tokenizer config for saving/loading
|
| 64 |
+
if not hasattr(self, 'init_kwargs'):
|
| 65 |
+
self.init_kwargs = {}
|
| 66 |
+
self.init_kwargs['start_string'] = self.start_string
|
| 67 |
+
self.init_kwargs['end_string'] = self.end_string
|
| 68 |
+
|
| 69 |
+
@classmethod
|
| 70 |
+
def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
| 71 |
+
"""
|
| 72 |
+
Load a tokenizer from a pretrained model or path.
|
| 73 |
+
|
| 74 |
+
This method ensures our custom class is used instead of the base GemmaTokenizerFast.
|
| 75 |
+
"""
|
| 76 |
+
# Load the base tokenizer first to get all configuration
|
| 77 |
+
base_tokenizer = GemmaTokenizerFast.from_pretrained(
|
| 78 |
+
pretrained_model_name_or_path, *args, **kwargs
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Create new instance of our custom class by copying the base tokenizer
|
| 82 |
+
custom_tokenizer = cls.__new__(cls)
|
| 83 |
+
|
| 84 |
+
# Copy all attributes from base tokenizer
|
| 85 |
+
for attr, value in base_tokenizer.__dict__.items():
|
| 86 |
+
setattr(custom_tokenizer, attr, value)
|
| 87 |
+
|
| 88 |
+
# Initialize our custom attributes for chat format
|
| 89 |
+
custom_tokenizer.start_string = "<start_of_turn>"
|
| 90 |
+
custom_tokenizer.end_string = "<end_of_turn>"
|
| 91 |
+
|
| 92 |
+
# Update init_kwargs to include our custom attributes
|
| 93 |
+
if not hasattr(custom_tokenizer, 'init_kwargs'):
|
| 94 |
+
custom_tokenizer.init_kwargs = {}
|
| 95 |
+
custom_tokenizer.init_kwargs['start_string'] = custom_tokenizer.start_string
|
| 96 |
+
custom_tokenizer.init_kwargs['end_string'] = custom_tokenizer.end_string
|
| 97 |
+
|
| 98 |
+
return custom_tokenizer
|
| 99 |
+
|
| 100 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
|
| 101 |
+
"""
|
| 102 |
+
Save the tokenizer to a directory, including custom configuration.
|
| 103 |
+
"""
|
| 104 |
+
# Call parent save method
|
| 105 |
+
super().save_pretrained(save_directory, **kwargs)
|
| 106 |
+
|
| 107 |
+
# Save custom configuration
|
| 108 |
+
config_file = os.path.join(save_directory, "tokenizer_config.json")
|
| 109 |
+
if os.path.exists(config_file):
|
| 110 |
+
with open(config_file, 'r') as f:
|
| 111 |
+
config = json.load(f)
|
| 112 |
+
else:
|
| 113 |
+
config = {}
|
| 114 |
+
|
| 115 |
+
# Add our custom class info
|
| 116 |
+
config["tokenizer_class"] = "GemmaChatTokenizer"
|
| 117 |
+
config["start_string"] = self.start_string
|
| 118 |
+
config["end_string"] = self.end_string
|
| 119 |
+
# Point to our custom class in the uploaded file
|
| 120 |
+
config["auto_map"] = {
|
| 121 |
+
"AutoTokenizer": ["gemma_chat_tokenizer.GemmaChatTokenizer", "gemma_chat_tokenizer.GemmaChatTokenizer"]
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
with open(config_file, 'w') as f:
|
| 125 |
+
json.dump(config, f, indent=2)
|
| 126 |
+
|
| 127 |
+
def messages_to_chat_messages(
|
| 128 |
+
self,
|
| 129 |
+
messages: List[Dict[str, Any]],
|
| 130 |
+
start_generation: bool = False,
|
| 131 |
+
default_user_message: str = "Generate.",
|
| 132 |
+
) -> List[Dict[str, Any]]:
|
| 133 |
+
"""
|
| 134 |
+
From messages (description / input / output) to chat messages (role / content)
|
| 135 |
+
Uses the same logic as chat_utils.py with system messages.
|
| 136 |
+
"""
|
| 137 |
+
chat_prompt = """You are tasked with generating outputs from a particular, potentially stochastic, generative process. You will be given some combination of:
|
| 138 |
+
- Description: A natural description of the generative process / data distribution
|
| 139 |
+
- Input: An input on which to condition the generative process.
|
| 140 |
+
- Example outputs: Example outputs from the process, either in a user message or as prior generations from a chat message. You may assume that any given outputs are exchangeable with one another (order-invariant) and generated from the same process (roughly i.i.d.). If the output data pertains to a single object, it just contains the output. If it contains multiple objects, use json formatting with keys for the name of the output variable.
|
| 141 |
+
You will be provided at least either a description or an example output.
|
| 142 |
+
|
| 143 |
+
Given these components, your job is to generate JUST the output in your response, roughly approximating the underlying generative process, maintaining any underlying stochasticity (if any is present). If you are asked to generate again, you will either be given an additional input to condition on, or will just be told to "Generate"."""
|
| 144 |
+
|
| 145 |
+
chat_messages = []
|
| 146 |
+
system_message = chat_prompt
|
| 147 |
+
has_description_or_output = False
|
| 148 |
+
has_input = False
|
| 149 |
+
|
| 150 |
+
for message in messages:
|
| 151 |
+
if message["role"] == "description":
|
| 152 |
+
system_message += "\n\nDescription: " + message["content"]
|
| 153 |
+
chat_messages.append({"role": "system", "content": system_message})
|
| 154 |
+
has_description_or_output = True
|
| 155 |
+
elif message["role"] == "input":
|
| 156 |
+
has_input = True
|
| 157 |
+
if not has_description_or_output:
|
| 158 |
+
system_message += "\n\nExample Input: " + message["content"]
|
| 159 |
+
else:
|
| 160 |
+
chat_messages.append({"role": "user", "content": message["content"]})
|
| 161 |
+
elif message["role"] == "output":
|
| 162 |
+
if not has_description_or_output:
|
| 163 |
+
system_message += "\n\nExample Output: " + message["content"]
|
| 164 |
+
chat_messages.append({"role": "system", "content": system_message})
|
| 165 |
+
has_description_or_output = True
|
| 166 |
+
else:
|
| 167 |
+
if not has_input:
|
| 168 |
+
chat_messages.append({"role": "user", "content": default_user_message})
|
| 169 |
+
chat_messages.append({"role": "assistant", "content": message["content"]})
|
| 170 |
+
|
| 171 |
+
if len(chat_messages) == 0:
|
| 172 |
+
# add system message
|
| 173 |
+
chat_messages.append({"role": "system", "content": system_message})
|
| 174 |
+
# also add in empty user message for now for gemma
|
| 175 |
+
chat_messages.append({"role": "user", "content": ""})
|
| 176 |
+
if len(chat_messages) == 1:
|
| 177 |
+
# add in empty user message for now for gemma
|
| 178 |
+
chat_messages.append({"role": "user", "content": ""})
|
| 179 |
+
|
| 180 |
+
# if last message is output and start_generation is true, add a default user message
|
| 181 |
+
if start_generation and chat_messages[-1]["role"] == "assistant":
|
| 182 |
+
chat_messages.append({"role": "user", "content": default_user_message})
|
| 183 |
+
|
| 184 |
+
return chat_messages
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def messages_to_loss_texts(
|
| 189 |
+
self,
|
| 190 |
+
messages: List[Dict[str, Any]],
|
| 191 |
+
loss_on_start_token: bool = False,
|
| 192 |
+
default_user_message: str = "Generate.",
|
| 193 |
+
start_generation: bool = False,
|
| 194 |
+
) -> List[Dict[str, Any]]:
|
| 195 |
+
"""
|
| 196 |
+
From messages (description / input / output) to texts (text / compute_loss) with whether or not loss should be calculated on the text for training.
|
| 197 |
+
Uses the chat format matching chat_utils.py with updated loss computation logic.
|
| 198 |
+
"""
|
| 199 |
+
# FOR NOW, OVERRIDING TO FALSE
|
| 200 |
+
loss_on_start_token = False
|
| 201 |
+
|
| 202 |
+
texts = []
|
| 203 |
+
|
| 204 |
+
chat_messages = self.messages_to_chat_messages(messages, start_generation=start_generation, default_user_message=default_user_message)
|
| 205 |
+
|
| 206 |
+
# Apply chat template
|
| 207 |
+
full_text = self.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=start_generation)
|
| 208 |
+
# replace <bos> with nothing
|
| 209 |
+
full_text = full_text.replace("<bos>", "")
|
| 210 |
+
|
| 211 |
+
text_to_split = full_text
|
| 212 |
+
# now, find all places starting with <start_of_turn>model\n
|
| 213 |
+
model_start_text = "<start_of_turn>model\n" # TODO - manual for now, change later
|
| 214 |
+
first = True
|
| 215 |
+
while model_start_text in text_to_split:
|
| 216 |
+
# get location of model_start_text
|
| 217 |
+
model_start_loc = text_to_split.find(model_start_text)
|
| 218 |
+
split_ind = model_start_loc + len(model_start_text)
|
| 219 |
+
text_to_add, text_to_split = text_to_split[:split_ind], text_to_split[split_ind:]
|
| 220 |
+
# add to texts
|
| 221 |
+
texts.append({"text": text_to_add, "compute_loss": False})
|
| 222 |
+
# get location of end_string
|
| 223 |
+
end_string_loc = text_to_split.find(self.end_string)
|
| 224 |
+
end_ind = end_string_loc + len(self.end_string)
|
| 225 |
+
text_to_add, text_to_split = text_to_split[:end_ind], text_to_split[end_ind:]
|
| 226 |
+
# Calculate loss on ALL assistant messages (removed conditional logic)
|
| 227 |
+
texts.append({"text": text_to_add, "compute_loss": True})
|
| 228 |
+
first = False
|
| 229 |
+
if len(text_to_split) > 0:
|
| 230 |
+
texts.append({"text": text_to_split, "compute_loss": False})
|
| 231 |
+
if len(texts) == 0:
|
| 232 |
+
breakpoint()
|
| 233 |
+
|
| 234 |
+
return texts
|
| 235 |
+
|
| 236 |
+
def messages_to_text(
|
| 237 |
+
self,
|
| 238 |
+
messages: List[Dict[str, Any]],
|
| 239 |
+
start_generation: bool = False,
|
| 240 |
+
) -> str:
|
| 241 |
+
"""
|
| 242 |
+
Messages (description / input / output) to raw text (text).
|
| 243 |
+
Uses the chat format matching chat_utils.py.
|
| 244 |
+
"""
|
| 245 |
+
texts = self.messages_to_loss_texts(messages, start_generation=start_generation)
|
| 246 |
+
text = "".join([text["text"] for text in texts])
|
| 247 |
+
return text
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def tokenize_messages(
|
| 251 |
+
self,
|
| 252 |
+
messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
|
| 253 |
+
start_generation: bool = False,
|
| 254 |
+
**kwargs,
|
| 255 |
+
):
|
| 256 |
+
"""
|
| 257 |
+
For tokenizing from messages to texts. Supports batching. Good for generation
|
| 258 |
+
"""
|
| 259 |
+
if isinstance(messages, list) and isinstance(messages[0], list):
|
| 260 |
+
# Handle list of lists of messages
|
| 261 |
+
all_texts = []
|
| 262 |
+
for message_list in messages:
|
| 263 |
+
texts = self.messages_to_text(message_list, start_generation)
|
| 264 |
+
all_texts.append(texts)
|
| 265 |
+
else:
|
| 266 |
+
# Handle single list of messages
|
| 267 |
+
texts = self.messages_to_text(messages, start_generation)
|
| 268 |
+
all_texts = [texts]
|
| 269 |
+
|
| 270 |
+
# Tokenize all texts
|
| 271 |
+
processed = self(text=all_texts, **kwargs)
|
| 272 |
+
return processed
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def tokenize_loss_texts(
|
| 276 |
+
self,
|
| 277 |
+
texts: List[Dict[str, Any]],
|
| 278 |
+
loss_on_start_token: bool = False,
|
| 279 |
+
loss_on_eos: bool = False,
|
| 280 |
+
include_eos: bool = True,
|
| 281 |
+
):
|
| 282 |
+
"""
|
| 283 |
+
Tokenize texts (text / compute_loss) to tokenized texts (input_ids / attention_mask / labels).
|
| 284 |
+
|
| 285 |
+
Needs more complex logic to handle the back and forth labeling.
|
| 286 |
+
"""
|
| 287 |
+
if loss_on_eos:
|
| 288 |
+
raise ValueError("Loss on EOS is not currently supported.")
|
| 289 |
+
|
| 290 |
+
# Handle single string input
|
| 291 |
+
if isinstance(texts, str):
|
| 292 |
+
processed = self(text=texts)
|
| 293 |
+
# Add EOS token if needed
|
| 294 |
+
if (self.eos_token_id is not None and
|
| 295 |
+
processed["input_ids"][-1] != self.eos_token_id):
|
| 296 |
+
processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
|
| 297 |
+
processed["attention_mask"] = processed["attention_mask"] + [1]
|
| 298 |
+
return processed
|
| 299 |
+
|
| 300 |
+
# Handle list of text dictionaries
|
| 301 |
+
all_processed = []
|
| 302 |
+
all_texts = ''
|
| 303 |
+
example_inds = []
|
| 304 |
+
dataset_inds = []
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
for i, item in enumerate(texts):
|
| 308 |
+
processed = self(text=item["text"])
|
| 309 |
+
|
| 310 |
+
# Remove BOS token from all but first item
|
| 311 |
+
if i != 0 and self.bos_token_id == processed["input_ids"][0]:
|
| 312 |
+
processed["input_ids"] = processed["input_ids"][1:]
|
| 313 |
+
processed["attention_mask"] = processed["attention_mask"][1:]
|
| 314 |
+
|
| 315 |
+
# Remove EOS token if present at the end
|
| 316 |
+
if processed["input_ids"][-1] == self.eos_token_id:
|
| 317 |
+
processed["input_ids"] = processed["input_ids"][:-1]
|
| 318 |
+
processed["attention_mask"] = processed["attention_mask"][:-1]
|
| 319 |
+
|
| 320 |
+
# Check for EOS token in the middle (with special handling for <|im_end|>)
|
| 321 |
+
if self.eos_token_id in processed["input_ids"]:
|
| 322 |
+
if not self.decode([self.eos_token_id]) == "<|im_end|>":
|
| 323 |
+
raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.")
|
| 324 |
+
|
| 325 |
+
# Set labels based on compute_loss flag
|
| 326 |
+
if item["compute_loss"]:
|
| 327 |
+
processed["labels"] = processed["input_ids"].copy()
|
| 328 |
+
else:
|
| 329 |
+
processed["labels"] = [-100] * len(processed["input_ids"])
|
| 330 |
+
|
| 331 |
+
# Remove duplicate BOS tokens
|
| 332 |
+
if all_processed:
|
| 333 |
+
if processed["input_ids"][0] == self.bos_token_id:
|
| 334 |
+
processed["input_ids"] = processed["input_ids"][1:]
|
| 335 |
+
processed["attention_mask"] = processed["attention_mask"][1:]
|
| 336 |
+
processed["labels"] = processed["labels"][1:]
|
| 337 |
+
|
| 338 |
+
all_processed.append(processed)
|
| 339 |
+
all_texts += item["text"]
|
| 340 |
+
|
| 341 |
+
# Handle example indices
|
| 342 |
+
this_num = -1
|
| 343 |
+
if 'example_ind' in item.keys():
|
| 344 |
+
if item["example_ind"] is not None:
|
| 345 |
+
this_num = item["example_ind"]
|
| 346 |
+
example_inds.extend([this_num] * len(processed["input_ids"]))
|
| 347 |
+
|
| 348 |
+
# Handle dataset indices
|
| 349 |
+
dataset_ind = -1
|
| 350 |
+
if "data_id" in item.keys():
|
| 351 |
+
if item["data_id"] is not None:
|
| 352 |
+
dataset_ind = item["data_id"]
|
| 353 |
+
dataset_inds.extend([dataset_ind] * len(processed["input_ids"]))
|
| 354 |
+
|
| 355 |
+
# Combine all processed results
|
| 356 |
+
processed = all_processed[0].copy()
|
| 357 |
+
processed["input_ids"] = [item for sublist in [p["input_ids"] for p in all_processed] for item in sublist]
|
| 358 |
+
processed["attention_mask"] = [item for sublist in [p["attention_mask"] for p in all_processed] for item in sublist]
|
| 359 |
+
processed["labels"] = [item for sublist in [p["labels"] for p in all_processed] for item in sublist]
|
| 360 |
+
processed["example_inds"] = example_inds
|
| 361 |
+
processed["data_ids"] = dataset_inds
|
| 362 |
+
|
| 363 |
+
# Validate by tokenizing all_texts at once and comparing
|
| 364 |
+
processed_all = self(text=all_texts)
|
| 365 |
+
if len(processed_all["input_ids"]) != len(processed["input_ids"]):
|
| 366 |
+
warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}")
|
| 367 |
+
|
| 368 |
+
# Generate diff for debugging
|
| 369 |
+
all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False)
|
| 370 |
+
processed_text = self.decode(processed["input_ids"], skip_special_tokens=False)
|
| 371 |
+
|
| 372 |
+
diff = difflib.unified_diff(all_text.splitlines(), processed_text.splitlines())
|
| 373 |
+
diff_str = "\n".join(diff)
|
| 374 |
+
print("Diff between texts:")
|
| 375 |
+
print(diff_str)
|
| 376 |
+
|
| 377 |
+
# Token diff
|
| 378 |
+
all_tokens_str = '\n'.join([str(s) for s in processed_all["input_ids"]])
|
| 379 |
+
processed_tokens_str = '\n'.join([str(s) for s in processed["input_ids"]])
|
| 380 |
+
token_diff = difflib.unified_diff(all_tokens_str.splitlines(), processed_tokens_str.splitlines())
|
| 381 |
+
token_diff_str = "\n".join(token_diff)
|
| 382 |
+
print("Diff between tokenized texts:")
|
| 383 |
+
print(token_diff_str)
|
| 384 |
+
|
| 385 |
+
# Add EOS token if needed
|
| 386 |
+
if (self.eos_token_id is not None and
|
| 387 |
+
processed["input_ids"][-1] != self.eos_token_id):
|
| 388 |
+
processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
|
| 389 |
+
processed["example_inds"] = processed["example_inds"] + [-1]
|
| 390 |
+
processed["attention_mask"] = processed["attention_mask"] + [1]
|
| 391 |
+
if processed["labels"] is not None:
|
| 392 |
+
if loss_on_eos:
|
| 393 |
+
processed["labels"] = processed["labels"] + [self.eos_token_id]
|
| 394 |
+
else:
|
| 395 |
+
processed["labels"] = processed["labels"] + [-100]
|
| 396 |
+
if "data_ids" in processed:
|
| 397 |
+
processed["data_ids"] = processed["data_ids"] + [-1]
|
| 398 |
+
|
| 399 |
+
if not include_eos:
|
| 400 |
+
# check if EOS token is present
|
| 401 |
+
if processed["input_ids"][-1] == self.eos_token_id:
|
| 402 |
+
# remove EOS token
|
| 403 |
+
processed["input_ids"] = processed["input_ids"][:-1]
|
| 404 |
+
processed["attention_mask"] = processed["attention_mask"][:-1]
|
| 405 |
+
processed["labels"] = processed["labels"][:-1]
|
| 406 |
+
processed["example_inds"] = processed["example_inds"][:-1]
|
| 407 |
+
processed["data_ids"] = processed["data_ids"][:-1]
|
| 408 |
+
|
| 409 |
+
return processed
|
| 410 |
+
|
| 411 |
+
def tokenize_messages(
|
| 412 |
+
self,
|
| 413 |
+
messages: List[Dict[str, Any]],
|
| 414 |
+
loss_on_start_token: bool = False,
|
| 415 |
+
loss_on_eos: bool = False,
|
| 416 |
+
include_eos: bool = True,
|
| 417 |
+
) -> Dict[str, Any]:
|
| 418 |
+
"""
|
| 419 |
+
Intended for tokenize from messages to tokenized texts with the loss applied.
|
| 420 |
+
"""
|
| 421 |
+
# First convert messages to text with loss computation flags
|
| 422 |
+
texts = self.messages_to_loss_texts(messages, loss_on_start_token)
|
| 423 |
+
|
| 424 |
+
# Then tokenize the texts
|
| 425 |
+
return self.tokenize_loss_texts(texts, loss_on_eos, include_eos = include_eos)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# Register tokenizer classes for AutoTokenizer
|
| 431 |
+
AutoTokenizer.register("GemmaChatTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaChatTokenizer)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
if __name__ == "__main__":
|
| 435 |
+
# Example usage
|
| 436 |
+
# for first load
|
| 437 |
+
custom_tokenizer = GemmaChatTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")
|
| 438 |
+
|
| 439 |
+
# for subsequent loads
|
| 440 |
+
# custom_tokenizer = GemmaChatTokenizer.from_pretrained("tsor13/chat-gemma-12b-pt")
|
| 441 |
+
# custom_tokenizer = GemmaChatTokenizer.from_pretrained("repos/chat-gemma-12b-pt")
|
| 442 |
+
|
| 443 |
+
# Test messages in role/content format
|
| 444 |
+
test_messages = [
|
| 445 |
+
[
|
| 446 |
+
{"role": "description", "content": "Pick a number between 1 and 100"},
|
| 447 |
+
],
|
| 448 |
+
|
| 449 |
+
[
|
| 450 |
+
{"role": "description", "content": "This is a test task"},
|
| 451 |
+
{"role": "input", "content": "What is 2+2?"},
|
| 452 |
+
{"role": "output", "content": "4"},
|
| 453 |
+
{"role": "input", "content": "What is 3+3?"},
|
| 454 |
+
],
|
| 455 |
+
[
|
| 456 |
+
{"role": "description", "content": "This is a test task"},
|
| 457 |
+
{"role": "output", "content": "4"},
|
| 458 |
+
{"role": "output", "content": "10"},
|
| 459 |
+
{"role": "output", "content": "13"},
|
| 460 |
+
],
|
| 461 |
+
[
|
| 462 |
+
{"role": "output", "content": "4"},
|
| 463 |
+
{"role": "output", "content": "10"},
|
| 464 |
+
{"role": "output", "content": "13"},
|
| 465 |
+
],
|
| 466 |
+
[
|
| 467 |
+
{"role": "input", "content": "What is 2+2?"},
|
| 468 |
+
{"role": "output", "content": "4"},
|
| 469 |
+
{"role": "input", "content": "What is 3+3?"},
|
| 470 |
+
{"role": "output", "content": "10"},
|
| 471 |
+
{"role": "input", "content": "What is 4+4?"},
|
| 472 |
+
],
|
| 473 |
+
[
|
| 474 |
+
{"role": "description", "content": "DESCRIPTION"},
|
| 475 |
+
{"role": "input", "content": "INPUT1"},
|
| 476 |
+
{"role": "output", "content": "OUTPUT1"},
|
| 477 |
+
{"role": "input", "content": "INPUT2"},
|
| 478 |
+
{"role": "output", "content": "OUTPUT2"},
|
| 479 |
+
],
|
| 480 |
+
[
|
| 481 |
+
{"role": "description", "content": "DESCRIPTION"},
|
| 482 |
+
{"role": "output", "content": "OUTPUT1"},
|
| 483 |
+
{"role": "output", "content": "OUTPUT2"},
|
| 484 |
+
],
|
| 485 |
+
]
|
| 486 |
+
for messages in test_messages:
|
| 487 |
+
# get messages to text_loss
|
| 488 |
+
texts = custom_tokenizer.messages_to_loss_texts(messages)
|
| 489 |
+
|
| 490 |
+
print("Texts with loss flags:")
|
| 491 |
+
for i, text in enumerate(texts):
|
| 492 |
+
print(f" {i}: {text}")
|
| 493 |
+
|
| 494 |
+
text = custom_tokenizer.messages_to_text(messages, start_generation=True)
|
| 495 |
+
print(f"\nFull text with generation prompt:")
|
| 496 |
+
print(text)
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
print("\nTesting save/load cycle:")
|
| 500 |
+
# Test saving and loading
|
| 501 |
+
tokenizer_path = "repos/chat-gemma-tokenizer"
|
| 502 |
+
custom_tokenizer.save_pretrained(tokenizer_path)
|
| 503 |
+
print("Tokenizer saved successfully!")
|
| 504 |
+
|
| 505 |
+
# also save this file in the tokenizer_path
|
| 506 |
+
import shutil
|
| 507 |
+
shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_chat_tokenizer.py"))
|
| 508 |
+
print("GemmaChatTokenizer.py saved successfully!")
|
generation_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 2,
|
| 3 |
+
"cache_implementation": "hybrid",
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": [
|
| 6 |
+
1,
|
| 7 |
+
106
|
| 8 |
+
],
|
| 9 |
+
"pad_token_id": 0,
|
| 10 |
+
"top_k": 64,
|
| 11 |
+
"top_p": 0.95,
|
| 12 |
+
"transformers_version": "4.51.3"
|
| 13 |
+
}
|
model-00001-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:767d7a9fea55e6c77a497deecbe9a96f956f1e97b0e48213f4978b94e4043eb2
|
| 3 |
+
size 4979902192
|
model-00002-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f2b5baaa7eea2abffcf0e846e58d4b84c10c6a15cdbbd11a67072e4522a49f1
|
| 3 |
+
size 4931296592
|
model-00003-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04e9e018d3d0757eb516c9acc1e5ff203a944ecb1a24d9f0d6e3e08a4af75b11
|
| 3 |
+
size 4931296656
|
model-00004-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:702c672a6a1951b6fc81dedb4276ea8e6f146abb0d041aaeec6a0a042c8beb3b
|
| 3 |
+
size 4931296656
|
model-00005-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e00779aa228c203f949575ec7aba50f81deec0ea80c4ecbf0be8c226c89c15f1
|
| 3 |
+
size 4601000928
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"boi_token": "<start_of_image>",
|
| 3 |
+
"bos_token": {
|
| 4 |
+
"content": "<bos>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false
|
| 9 |
+
},
|
| 10 |
+
"eoi_token": "<end_of_image>",
|
| 11 |
+
"eos_token": {
|
| 12 |
+
"content": "<eos>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false
|
| 17 |
+
},
|
| 18 |
+
"image_token": "<image_soft_token>",
|
| 19 |
+
"pad_token": {
|
| 20 |
+
"content": "<pad>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false
|
| 25 |
+
},
|
| 26 |
+
"unk_token": {
|
| 27 |
+
"content": "<unk>",
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"normalized": false,
|
| 30 |
+
"rstrip": false,
|
| 31 |
+
"single_word": false
|
| 32 |
+
}
|
| 33 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
|
| 3 |
+
size 33384568
|
tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:267c12b1809a697fee02ca2efc4f0dd076c1f378edc40b9163dacfb3f1028db9
|
| 3 |
+
size 7313
|