tsor13 committed (verified)
Commit facc2b8 · Parent: f4a1cbb

Initial upload of fine-tuned Gemma + custom tokenizer

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,221 @@
+ ### tsor13/Special12b
+ The following is a model trained by [...suspense...] that is meant to:
+ - follow instructions better than pretrained models and be more diverse / less mode-collapsed than instruct models;
+ - be a really good, approximately Bayesian in-context learner;
+ - fit a data generation process;
+ - be calibrated over distributions of possible outputs with respect to a population or epistemic uncertainty.
+ It is initialized from `google/gemma-3-12b-pt`.
+
+ This model/repo is a work in progress - expect updates.
+
+ Loading model example:
+ ```
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ tokenizer = AutoTokenizer.from_pretrained("tsor13/special12b", trust_remote_code=True)  # custom tokenizer for handling messages / loss
+ model = AutoModelForCausalLM.from_pretrained("tsor13/special12b", device_map="auto")
+ ```
+
+ It has its own chat-style input messages, with the following roles:
+ - `description` (optional): A description of the generating process, or some information meant to instantiate a prior;
+ - `input` (optional): Any variables that the model is not responsible for predicting, but that can be used to condition generation;
+ - `output`: This is what the model will actually predict / generate.
+
+ For example,
+ ```
+ messages = [
+     {"role": "description", "content": "Capitals"},
+     {"role": "input", "content": "France"},
+     {"role": "output", "content": "Paris"},
+     {"role": "input", "content": "Japan"},
+ ]
+ ```
+
+ To templatize the messages, you can use the tokenizer:
+ ```
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
+ print(formatted_prompt)  # start_generation adds the <start_of_turn> token to condition the model for generation
+ ```
+ Output:
+ ```
+ Capitals
+ France
+ <start_of_turn>Paris<end_of_turn>
+ Japan
+ <start_of_turn>
+ ```
+ The data for the model to emulate / generate is wrapped in `<start_of_turn>` / `<end_of_turn>` tokens.
+ Description and input are not wrapped in anything. Thus, do not expect the model to generate these tokens - instead, focus on the wrapped output tokens.
+ Messages are separated by newlines.
+
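The wrapping scheme above is simple enough to sketch by hand. The helper below is a hypothetical re-implementation of the documented template, for illustration only (the real formatting lives in this repo's custom tokenizer, `messages_to_text`); it reproduces the Capitals example:

```python
# Hypothetical re-implementation of the documented template, for illustration
# only; the real logic lives in the repo's custom tokenizer (messages_to_text).
def format_messages(messages, start_generation=False):
    lines = []
    for m in messages:
        if m["role"] == "output":
            # outputs are wrapped in turn tokens
            lines.append(f"<start_of_turn>{m['content']}<end_of_turn>")
        else:
            # description / input are left unwrapped
            lines.append(m["content"])
    if start_generation:
        # a trailing <start_of_turn> conditions the model to generate an output
        lines.append("<start_of_turn>")
    return "\n".join(lines)

messages = [
    {"role": "description", "content": "Capitals"},
    {"role": "input", "content": "France"},
    {"role": "output", "content": "Paris"},
    {"role": "input", "content": "Japan"},
]
print(format_messages(messages, start_generation=True))
# Capitals
# France
# <start_of_turn>Paris<end_of_turn>
# Japan
# <start_of_turn>
```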
+ In training, loss is ONLY calculated on the output tokens and the `<end_of_turn>` token. Thus, the model is only designed to generate / predict probabilities after `<start_of_turn>` and until `<end_of_turn>` - everything else is out of distribution for the model and not recommended.
+
+ Once you have the formatted text, you can tokenize as normal:
+ ```
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+ ```
+
+ Let's look at what the model does. In this case, there is a single correct answer. Let's look at model probabilities after `<start_of_turn>`:
+ ```
+ import torch
+ with torch.no_grad():
+     output = model(**inputs)
+ logits = output.logits[0, -1, :]
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ top_probs, top_indices = torch.topk(probs, 10)
+ print("\nTop 10 probabilities for first output token:")
+ for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
+     token = tokenizer.decode(idx)
+     print(f"{i+1:2d}. '{token}' -> {prob.item():.4f}")
+ ```
+
+ Output:
+ ```
+ Top 10 probabilities for first output token:
+  1. 'Tokyo' -> 0.9764
+  2. 'Tok' -> 0.0070
+  3. '東京' -> 0.0026
+  4. 'Ky' -> 0.0019
+  5. 'T' -> 0.0014
+  6. ' Tokyo' -> 0.0014
+  7. 'To' -> 0.0011
+  8. 'Osaka' -> 0.0009
+  9. 'Toy' -> 0.0007
+ 10. 'tok' -> 0.0005
+ ```
+
+ Great! Almost all of the probability mass is on the correct answer, Tokyo.
+
+ Let's try an example with many possible reasonable choices / a harder-to-describe distribution. For example, say that I'm interested in modeling "board games that I like". I may be hard-pressed to actually describe what it is that I like about games - but I could provide a few examples pretty easily.
+
+ ```
+ messages = [
+     {"role": "output", "content": "Dune: Imperium"},
+     {"role": "output", "content": "Acquire"},
+     {"role": "output", "content": "Catan"},
+     {"role": "output", "content": "Tigris and Euphrates"},
+     {"role": "output", "content": "Brass: Birmingham"},
+ ]
+ ```
+
+ Given these example outputs, the model will try to generate more outputs like them.
+ ```
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
+ n_gens = 4
+ inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
+
+ outputs = model.generate(**inputs, max_new_tokens=10, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
+ for i in range(n_gens):
+     print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
+ ```
+
+ Outputs:
+ ```
+ Catan: Rivals for Catan
+ Gloomhaven
+ Great Western Trail
+ Azul
+ ```
+ Not too bad!
+
+ You can also specify just the description:
+ Input:
+ ```
+ messages = [
+     {"role": "description", "content": "Descriptive colors"},
+ ]
+
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
+ n_gens = 4
+ inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
+
+ outputs = model.generate(**inputs, max_new_tokens=10, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
+ for i in range(n_gens):
+     print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
+     print()
+ ```
+ Output:
+ ```
+ Deep Sea Blue
+ Gray#222222
+ Gold, Red, Black
+ I can’t believe we’re already talking about color theory. How is this possible? Can time go any faster? Also how does your body
+ ```
+
+ By default, the model is trained to either 1) emulate outputs if examples are provided, or 2) generate data based on the description. Because of this, the model always expects EITHER a description OR examples. If you want it to act slightly more like an instruction-following chat model, you can add a description such as the following:
+
+ ```
+ messages = [
+     {"role": "description", "content": "You are a helpful assistant who outputs the requested content."},
+     {"role": "input", "content": "A poem about a shark"},
+ ]
+ ```
+ To generate:
+ ```
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
+ n_gens = 4
+ inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
+
+ outputs = model.generate(**inputs, max_new_tokens=40, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
+ for i in range(n_gens):
+     print(f"Generation {i}:")
+     print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
+ ```
+
+ Some example generations:
+ ```
+ Generation 0:
+ A deep-sea creature, silent and fierce, Shivers through water, its body sleek. Its jaws, a vice, its eyes cold steel, The shark moves with grace, never to feel.
+ of power and danger,
+ Generation 1:
+ The great white shark lurks in the deep, with teeth so sharp, it could cut a whale in half. Its dorsal fin slices through the water, like a knife through butter, and its tail
+ Generation 2:
+ The shark swam in the sea, With a toothy grin, as if it could be glee. It was the top of the food chain, The apex of the sea's terrain. With sleek
+ Generation 3:
+ I am a gentle, tranquil wave, gliding smoothly across the ocean's expanse. Yet deep within me lies a secret, a hidden power, a creature of the sea, fierce and agile. It
+ ```
+
+ Finally, let's look at a synthetic data generation task. For example, maybe we want to generate situations to do social reasoning over, along with whether or not they are awkward. When there are multiple variables to condition on or generate, the model is used to JSON format.
+
+ Input:
+ ```
+ import json
+ messages = [
+     {"role": "description", "content": "Situations to do social reasoning over, along with whether or not it is an awkward situation."},
+     {"role": "output", "content": json.dumps({
+         "situation": "You're at a party and you realize that your shirt is on backwards.",
+         "is_awkward": True,
+     })},
+     {"role": "output", "content": json.dumps({
+         "situation": "While at work, your boss commends you on a job well done.",
+         "is_awkward": False,
+     })},
+     {"role": "output", "content": json.dumps({
+         "situation": "Realizing you forgot to bring your passport to the airport.",
+         "is_awkward": True,
+     })},
+ ]
+
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
+ n_gens = 4
+ inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
+
+ outputs = model.generate(**inputs, max_new_tokens=40, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
+ for i in range(n_gens):
+     print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
+ ```
+ Output:
+ ```
+ {"situation": "While walking on the street, someone waves and smiles at you, but you don't know them.", "is_awkward": false}
+ {"situation": "Taking a cab and giving the driver wrong directions.", "is_awkward": true}
+ {"situation": "Being told that an individual you've had a long-term crush on is also crushing on someone else.", "is_awkward": true}
+ {"situation": "Watching a loved one get proposed to.", "is_awkward": false}
+ ```
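Since each generation is a single JSON line, the outputs can be parsed straight back into Python objects for downstream filtering. The `generated` list below just hard-codes two of the lines shown above for illustration; in practice you would collect them from `model.generate`:

```python
import json

# Two of the generated lines shown above, parsed back into Python objects.
generated = [
    '{"situation": "Taking a cab and giving the driver wrong directions.", "is_awkward": true}',
    '{"situation": "Watching a loved one get proposed to.", "is_awkward": false}',
]
records = [json.loads(line) for line in generated]

# Keep only the awkward situations.
awkward = [r["situation"] for r in records if r["is_awkward"]]
print(awkward)
# ['Taking a cab and giving the driver wrong directions.']
```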
+
+ A few tips and tricks:
+ - Do not expect the model to handle multi-turn chats. It is designed to be stateless and to treat each data point as "exchangeable" (roughly i.i.d.).
+ - If all you want is one reasonable answer, a chat model is likely a better fit. However, if you want to generate many reasonable answers / diverse examples, this model is a better fit.
+ - The model is quite good at perspective-taking / steering if you provide many examples.
+ - The model is reasonably good at expressing epistemic uncertainty over uncertain outputs when you sample several times.
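The last tip (reading epistemic uncertainty off repeated samples) amounts to building an empirical distribution over generations. A minimal sketch, with placeholder samples standing in for repeated `model.generate` calls on the same prompt:

```python
from collections import Counter

# Placeholder samples; in practice these would come from repeated
# model.generate calls on the same formatted prompt.
samples = ["true", "true", "false", "true", "false", "true", "true", "false"]

# Tally the samples into an empirical distribution over outputs.
counts = Counter(samples)
distribution = {answer: count / len(samples) for answer, count in counts.items()}
print(distribution)
# {'true': 0.625, 'false': 0.375}
```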
added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "<image_soft_token>": 262144
+ }
config.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "architectures": [
+     "Gemma3ForConditionalGeneration"
+   ],
+   "boi_token_index": 255999,
+   "eoi_token_index": 256000,
+   "eos_token_id": [
+     1,
+     106
+   ],
+   "image_token_index": 262144,
+   "initializer_range": 0.02,
+   "mm_tokens_per_image": 256,
+   "model_type": "gemma3",
+   "text_config": {
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "attn_logit_softcapping": null,
+     "cache_implementation": "hybrid",
+     "final_logit_softcapping": null,
+     "head_dim": 256,
+     "hidden_activation": "gelu_pytorch_tanh",
+     "hidden_size": 3840,
+     "initializer_range": 0.02,
+     "intermediate_size": 15360,
+     "max_position_embeddings": 131072,
+     "model_type": "gemma3_text",
+     "num_attention_heads": 16,
+     "num_hidden_layers": 48,
+     "num_key_value_heads": 8,
+     "query_pre_attn_scalar": 256,
+     "rms_norm_eps": 1e-06,
+     "rope_local_base_freq": 10000.0,
+     "rope_scaling": {
+       "factor": 8.0,
+       "rope_type": "linear"
+     },
+     "rope_theta": 1000000.0,
+     "sliding_window": 1024,
+     "sliding_window_pattern": 6,
+     "torch_dtype": "float32",
+     "use_cache": true,
+     "vocab_size": 262208
+   },
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.51.3",
+   "vision_config": {
+     "attention_dropout": 0.0,
+     "hidden_act": "gelu_pytorch_tanh",
+     "hidden_size": 1152,
+     "image_size": 896,
+     "intermediate_size": 4304,
+     "layer_norm_eps": 1e-06,
+     "model_type": "siglip_vision_model",
+     "num_attention_heads": 16,
+     "num_channels": 3,
+     "num_hidden_layers": 27,
+     "patch_size": 14,
+     "torch_dtype": "float32",
+     "vision_use_head": false
+   }
+ }
gemma_chat_tokenizer.py ADDED
@@ -0,0 +1,448 @@
+ """
+ Custom Gemma Tokenizer for Chat Format
+
+ This tokenizer implements the chat format for message processing:
+ Format: uses the standard chat template with proper role labels (user/assistant).
+
+ The chat format uses the model's built-in chat template and includes proper
+ loss computation flags for training, with "assistant" as the generation role.
+
+ To save:
+     uv run tokenizers/gemma_chat_tokenizer.py
+ which will save the tokenizer to the repos/chat-gemma-tokenizer directory.
+     mkdir repos/chat12b
+     # copy model over
+     cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_chat/checkpoint-8/* repos/chat12b/
+     # copy tokenizer over
+     cp repos/chat-gemma-tokenizer/* repos/chat12b/
+     # upload to hf
+     uv run upload_to_hf.py \
+         --folder repos/chat12b \
+         --repo-id tsor13/chat12b
+ """
+
+ from typing import List, Dict, Any, Union
+ from transformers import AutoTokenizer
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
+ import warnings
+ import difflib
+ import json
+ import os
+
+
+ class GemmaChatTokenizer(GemmaTokenizerFast):
+     """
+     Custom tokenizer for Gemma models that implements chat format message processing.
+
+     This tokenizer formats messages using the chat format where:
+     - Messages use the standard chat template with proper role labels
+     - Formatting uses the model's built-in chat template
+     - Loss is computed on the assistant sections (not output)
+
+     Attributes:
+         start_string (str): The starting string used for output generation (depends on tokenizer)
+         end_string (str): The ending string used for output generation (depends on tokenizer)
+     """
+
+     def __init__(self, *args, **kwargs):
+         """
+         Initialize the custom tokenizer.
+
+         Accepts the same arguments as GemmaTokenizerFast.
+         """
+         super().__init__(*args, **kwargs)
+
+         # For chat format, we use the tokenizer's own chat template
+         # The start/end strings will be determined by the chat template
+         self.start_string = "<start_of_turn>"  # Will be set dynamically
+         self.end_string = "<end_of_turn>"  # Will be set dynamically
+
+         # Add custom attributes to the tokenizer config for saving/loading
+         if not hasattr(self, 'init_kwargs'):
+             self.init_kwargs = {}
+         self.init_kwargs['start_string'] = self.start_string
+         self.init_kwargs['end_string'] = self.end_string
+
+     @classmethod
+     def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+         """
+         Load a tokenizer from a pretrained model or path.
+
+         This method ensures our custom class is used instead of the base GemmaTokenizerFast.
+         """
+         # Load the base tokenizer first to get all configuration
+         base_tokenizer = GemmaTokenizerFast.from_pretrained(
+             pretrained_model_name_or_path, *args, **kwargs
+         )
+
+         # Create a new instance of our custom class by copying the base tokenizer
+         custom_tokenizer = cls.__new__(cls)
+
+         # Copy all attributes from the base tokenizer
+         for attr, value in base_tokenizer.__dict__.items():
+             setattr(custom_tokenizer, attr, value)
+
+         # Initialize our custom attributes for chat format
+         custom_tokenizer.start_string = "<start_of_turn>"
+         custom_tokenizer.end_string = "<end_of_turn>"
+
+         # Update init_kwargs to include our custom attributes
+         if not hasattr(custom_tokenizer, 'init_kwargs'):
+             custom_tokenizer.init_kwargs = {}
+         custom_tokenizer.init_kwargs['start_string'] = custom_tokenizer.start_string
+         custom_tokenizer.init_kwargs['end_string'] = custom_tokenizer.end_string
+
+         return custom_tokenizer
+
+     def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
+         """
+         Save the tokenizer to a directory, including custom configuration.
+         """
+         # Call parent save method
+         super().save_pretrained(save_directory, **kwargs)
+
+         # Save custom configuration
+         config_file = os.path.join(save_directory, "tokenizer_config.json")
+         if os.path.exists(config_file):
+             with open(config_file, 'r') as f:
+                 config = json.load(f)
+         else:
+             config = {}
+
+         # Add our custom class info
+         config["tokenizer_class"] = "GemmaChatTokenizer"
+         config["start_string"] = self.start_string
+         config["end_string"] = self.end_string
+         # Point to our custom class in the uploaded file
+         config["auto_map"] = {
+             "AutoTokenizer": ["gemma_chat_tokenizer.GemmaChatTokenizer", "gemma_chat_tokenizer.GemmaChatTokenizer"]
+         }
+
+         with open(config_file, 'w') as f:
+             json.dump(config, f, indent=2)
+
+     def messages_to_loss_texts(
+         self,
+         messages: List[Dict[str, Any]],
+         loss_on_start_token: bool = False,
+         default_user_message: str = "Generate.",
+         start_generation: bool = False,
+     ) -> List[Dict[str, Any]]:
+         """
+         Convert messages (description / input / output) to texts (text / compute_loss), flagging whether loss should be calculated on each text during training.
+         Uses the chat format matching chat_utils.py.
+         """
+         # FOR NOW, OVERRIDING TO FALSE
+         loss_on_start_token = False
+
+         texts = []
+         chat_messages = []
+         has_input = False
+         has_description = False
+
+         # Convert to chat format
+         for message in messages:
+             if message["role"] == "description":
+                 chat_messages.append({"role": "system", "content": "Generate something that fits this description. Don't generate anything else, just the desired generation output.\nDescription: " + message["content"]})
+                 has_description = True
+             elif message["role"] == "input":
+                 has_input = True
+                 chat_messages.append({"role": "user", "content": message["content"]})
+             elif message["role"] == "output":
+                 if not has_input:
+                     chat_messages.append({"role": "user", "content": default_user_message})
+                 chat_messages.append({"role": "assistant", "content": message["content"]})
+         # if the last message is an output and start_generation is true, add a default user message
+         if start_generation and chat_messages[-1]["role"] == "assistant":
+             chat_messages.append({"role": "user", "content": default_user_message})
+
+         # Apply chat template
+         full_text = self.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=start_generation)
+         # replace <bos> with nothing
+         full_text = full_text.replace("<bos>", "")
+
+         text_to_split = full_text
+         # now, find all places starting with <start_of_turn>model\n
+         model_start_text = "<start_of_turn>model\n"  # TODO - manual for now, change later
+         first = True
+         while model_start_text in text_to_split:
+             # get location of model_start_text
+             model_start_loc = text_to_split.find(model_start_text)
+             split_ind = model_start_loc + len(model_start_text)
+             text_to_add, text_to_split = text_to_split[:split_ind], text_to_split[split_ind:]
+             # add to texts
+             texts.append({"text": text_to_add, "compute_loss": False})
+             # get location of end_string
+             end_string_loc = text_to_split.find(self.end_string)
+             end_ind = end_string_loc + len(self.end_string)
+             text_to_add, text_to_split = text_to_split[:end_ind], text_to_split[end_ind:]
+             if first and not has_description:
+                 texts.append({"text": text_to_add, "compute_loss": False})
+             else:
+                 texts.append({"text": text_to_add, "compute_loss": True})
+             first = False
+         if len(text_to_split) > 0:
+             texts.append({"text": text_to_split, "compute_loss": False})
+
+         return texts
+
+     def messages_to_text(
+         self,
+         messages: List[Dict[str, Any]],
+         start_generation: bool = False,
+     ) -> str:
+         """
+         Convert messages (description / input / output) to raw text.
+         Uses the chat format matching chat_utils.py.
+         """
+         texts = self.messages_to_loss_texts(messages, start_generation=start_generation)
+         text = "".join([t["text"] for t in texts])
+         return text
+
+     def tokenize_messages_for_generation(
+         self,
+         messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
+         start_generation: bool = False,
+         **kwargs,
+     ):
+         """
+         Tokenize from messages to texts. Supports batching; good for generation.
+         Distinct from `tokenize_messages`, which also produces training labels.
+         """
+         if isinstance(messages, list) and isinstance(messages[0], list):
+             # Handle a list of lists of messages
+             all_texts = []
+             for message_list in messages:
+                 texts = self.messages_to_text(message_list, start_generation)
+                 all_texts.append(texts)
+         else:
+             # Handle a single list of messages
+             texts = self.messages_to_text(messages, start_generation)
+             all_texts = [texts]
+
+         # Tokenize all texts
+         processed = self(text=all_texts, **kwargs)
+         return processed
+
+
+     def tokenize_loss_texts(
+         self,
+         texts: List[Dict[str, Any]],
+         loss_on_start_token: bool = False,
+         loss_on_eos: bool = False,
+         include_eos: bool = True,
+     ):
+         """
+         Tokenize texts (text / compute_loss) into tokenized texts (input_ids / attention_mask / labels).
+
+         Needs more complex logic to handle the back-and-forth labeling.
+         """
+         if loss_on_eos:
+             raise ValueError("Loss on EOS is not currently supported.")
+
+         # Handle single string input
+         if isinstance(texts, str):
+             processed = self(text=texts)
+             # Add EOS token if needed
+             if (self.eos_token_id is not None and
+                     processed["input_ids"][-1] != self.eos_token_id):
+                 processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
+                 processed["attention_mask"] = processed["attention_mask"] + [1]
+             return processed
+
+         # Handle list of text dictionaries
+         all_processed = []
+         all_texts = ''
+         example_inds = []
+         dataset_inds = []
+
+         for i, item in enumerate(texts):
+             processed = self(text=item["text"])
+
+             # Remove BOS token from all but the first item
+             if i != 0 and self.bos_token_id == processed["input_ids"][0]:
+                 processed["input_ids"] = processed["input_ids"][1:]
+                 processed["attention_mask"] = processed["attention_mask"][1:]
+
+             # Remove EOS token if present at the end
+             if processed["input_ids"][-1] == self.eos_token_id:
+                 processed["input_ids"] = processed["input_ids"][:-1]
+                 processed["attention_mask"] = processed["attention_mask"][:-1]
+
+             # Check for EOS token in the middle (with special handling for <|im_end|>)
+             if self.eos_token_id in processed["input_ids"]:
+                 if not self.decode([self.eos_token_id]) == "<|im_end|>":
+                     raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.")
+
+             # Set labels based on compute_loss flag
+             if item["compute_loss"]:
+                 processed["labels"] = processed["input_ids"].copy()
+             else:
+                 processed["labels"] = [-100] * len(processed["input_ids"])
+
+             # Remove duplicate BOS tokens
+             if all_processed:
+                 if processed["input_ids"][0] == self.bos_token_id:
+                     processed["input_ids"] = processed["input_ids"][1:]
+                     processed["attention_mask"] = processed["attention_mask"][1:]
+                     processed["labels"] = processed["labels"][1:]
+
+             all_processed.append(processed)
+             all_texts += item["text"]
+
+             # Handle example indices
+             this_num = -1
+             if 'example_ind' in item.keys():
+                 if item["example_ind"] is not None:
+                     this_num = item["example_ind"]
+             example_inds.extend([this_num] * len(processed["input_ids"]))
+
+             # Handle dataset indices
+             dataset_ind = -1
+             if "data_id" in item.keys():
+                 if item["data_id"] is not None:
+                     dataset_ind = item["data_id"]
+             dataset_inds.extend([dataset_ind] * len(processed["input_ids"]))
+
+         # Combine all processed results
+         processed = all_processed[0].copy()
+         processed["input_ids"] = [item for sublist in [p["input_ids"] for p in all_processed] for item in sublist]
+         processed["attention_mask"] = [item for sublist in [p["attention_mask"] for p in all_processed] for item in sublist]
+         processed["labels"] = [item for sublist in [p["labels"] for p in all_processed] for item in sublist]
+         processed["example_inds"] = example_inds
+         processed["data_ids"] = dataset_inds
+
+         # Validate by tokenizing all_texts at once and comparing
+         processed_all = self(text=all_texts)
+         if len(processed_all["input_ids"]) != len(processed["input_ids"]):
+             warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}")
+
+             # Generate diff for debugging
+             all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False)
+             processed_text = self.decode(processed["input_ids"], skip_special_tokens=False)
+
+             diff = difflib.unified_diff(all_text.splitlines(), processed_text.splitlines())
+             diff_str = "\n".join(diff)
+             print("Diff between texts:")
+             print(diff_str)
+
+             # Token diff
+             all_tokens_str = '\n'.join([str(s) for s in processed_all["input_ids"]])
+             processed_tokens_str = '\n'.join([str(s) for s in processed["input_ids"]])
+             token_diff = difflib.unified_diff(all_tokens_str.splitlines(), processed_tokens_str.splitlines())
+             token_diff_str = "\n".join(token_diff)
+             print("Diff between tokenized texts:")
+             print(token_diff_str)
+
+         # Add EOS token if needed
+         if (self.eos_token_id is not None and
+                 processed["input_ids"][-1] != self.eos_token_id):
+             processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
+             processed["example_inds"] = processed["example_inds"] + [-1]
+             processed["attention_mask"] = processed["attention_mask"] + [1]
+             if processed["labels"] is not None:
+                 if loss_on_eos:
+                     processed["labels"] = processed["labels"] + [self.eos_token_id]
+                 else:
+                     processed["labels"] = processed["labels"] + [-100]
+             if "data_ids" in processed:
+                 processed["data_ids"] = processed["data_ids"] + [-1]
+
+         if not include_eos:
+             # check if EOS token is present
+             if processed["input_ids"][-1] == self.eos_token_id:
+                 # remove EOS token
+                 processed["input_ids"] = processed["input_ids"][:-1]
+                 processed["attention_mask"] = processed["attention_mask"][:-1]
+                 processed["labels"] = processed["labels"][:-1]
+                 processed["example_inds"] = processed["example_inds"][:-1]
+                 processed["data_ids"] = processed["data_ids"][:-1]
+
+         return processed
+
+     def tokenize_messages(
+         self,
+         messages: List[Dict[str, Any]],
+         loss_on_start_token: bool = False,
+         loss_on_eos: bool = False,
+         include_eos: bool = True,
+     ) -> Dict[str, Any]:
+         """
+         Tokenize from messages to tokenized texts with the loss mask applied.
+         """
+         # First convert messages to text with loss computation flags
+         texts = self.messages_to_loss_texts(messages, loss_on_start_token)
+         # Then tokenize the texts (keyword arguments, so loss_on_eos is not
+         # silently bound to the loss_on_start_token parameter)
+         return self.tokenize_loss_texts(texts, loss_on_eos=loss_on_eos, include_eos=include_eos)
+
+
+ # Register tokenizer classes for AutoTokenizer
+ AutoTokenizer.register("GemmaChatTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaChatTokenizer)
+
+
+ if __name__ == "__main__":
+     # Example usage
+     # for first load
+     custom_tokenizer = GemmaChatTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")
+
+     # for subsequent loads
+     # custom_tokenizer = GemmaChatTokenizer.from_pretrained("tsor13/chat-gemma-12b-pt")
+     # custom_tokenizer = GemmaChatTokenizer.from_pretrained("repos/chat-gemma-12b-pt")
+
+     # Test messages in role/content format
+     test_messages = [
+         [
+             {"role": "description", "content": "This is a test task"},
+             {"role": "input", "content": "What is 2+2?"},
+             {"role": "output", "content": "4"},
+             {"role": "input", "content": "What is 3+3?"},
+         ],
+         [
+             {"role": "description", "content": "This is a test task"},
+             {"role": "output", "content": "4"},
+             {"role": "output", "content": "10"},
+             {"role": "output", "content": "13"},
+         ],
+         [
+             {"role": "output", "content": "4"},
+             {"role": "output", "content": "10"},
+             {"role": "output", "content": "13"},
+         ],
+         [
+             {"role": "input", "content": "What is 2+2?"},
+             {"role": "output", "content": "4"},
+             {"role": "input", "content": "What is 3+3?"},
+             {"role": "output", "content": "10"},
+             {"role": "input", "content": "What is 4+4?"},
+         ],
+     ]
+     for messages in test_messages:
+         # convert messages to texts with loss flags
+         texts = custom_tokenizer.messages_to_loss_texts(messages)
+
+         print("Texts with loss flags:")
+         for i, text in enumerate(texts):
+             print(f"  {i}: {text}")
+
+         text = custom_tokenizer.messages_to_text(messages, start_generation=True)
+         print("\nFull text with generation prompt:")
+         print(text)
+
+     print("\nTesting save/load cycle:")
+     # Test saving and loading
+     tokenizer_path = "repos/chat-gemma-tokenizer"
+     custom_tokenizer.save_pretrained(tokenizer_path)
+     print("Tokenizer saved successfully!")
+
+     # also save this file in the tokenizer_path
+     import shutil
+     shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_chat_tokenizer.py"))
+     print("gemma_chat_tokenizer.py saved successfully!")
generation_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "bos_token_id": 2,
+   "cache_implementation": "hybrid",
+   "do_sample": true,
+   "eos_token_id": [
+     1,
+     106
+   ],
+   "pad_token_id": 0,
+   "top_k": 64,
+   "top_p": 0.95,
+   "transformers_version": "4.51.3"
+ }
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b6cef91d3c32c85e95d0b452e190b7701b85ee4ff313d483349b24290f62afee
+ size 4979902192
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0accacd83d92614807758dfc0661690fc9f3d9f6556ae8b46e61de5a70d3bdbf
+ size 4931296592
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4147b87360deb84e430607bf3686e5ec29d2ab175978d8be42294151f7b6e26d
+ size 4931296656
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:14c35cdcf195c4ac45cdaacb811f0f8c7e4130de125de48aea9c700bb9789a45
+ size 4931296656
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8e287f5a199f0d42dac15a5434426f96ea2e2c44a6ac1b6981a1c73338f67aa6
+ size 4601000928
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "boi_token": "<start_of_image>",
+   "bos_token": {
+     "content": "<bos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eoi_token": "<end_of_image>",
+   "eos_token": {
+     "content": "<eos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "image_token": "<image_soft_token>",
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
+ size 33384568
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24572f35d542bcfa4879d51a8c08dbd9d362bc088e00db35b31d5234790d8fe5
+ size 7377