tsor13 committed on
Commit
13efc11
·
verified ·
1 Parent(s): 779540b

Initial upload of fine‑tuned Gemma + custom tokenizer

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Gemma3ForConditionalGeneration"
4
+ ],
5
+ "boi_token_index": 255999,
6
+ "eoi_token_index": 256000,
7
+ "image_token_index": 262144,
8
+ "initializer_range": 0.02,
9
+ "mm_tokens_per_image": 256,
10
+ "model_type": "gemma3",
11
+ "text_config": {
12
+ "attention_bias": false,
13
+ "attention_dropout": 0.0,
14
+ "attn_logit_softcapping": null,
15
+ "cache_implementation": "hybrid",
16
+ "final_logit_softcapping": null,
17
+ "head_dim": 256,
18
+ "hidden_activation": "gelu_pytorch_tanh",
19
+ "hidden_size": 3840,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 15360,
22
+ "max_position_embeddings": 131072,
23
+ "model_type": "gemma3_text",
24
+ "num_attention_heads": 16,
25
+ "num_hidden_layers": 48,
26
+ "num_key_value_heads": 8,
27
+ "query_pre_attn_scalar": 256,
28
+ "rms_norm_eps": 1e-06,
29
+ "rope_local_base_freq": 10000.0,
30
+ "rope_scaling": {
31
+ "factor": 8.0,
32
+ "rope_type": "linear"
33
+ },
34
+ "rope_theta": 1000000.0,
35
+ "sliding_window": 1024,
36
+ "sliding_window_pattern": 6,
37
+ "torch_dtype": "float32",
38
+ "use_cache": true,
39
+ "vocab_size": 262208
40
+ },
41
+ "torch_dtype": "bfloat16",
42
+ "transformers_version": "4.51.3",
43
+ "vision_config": {
44
+ "attention_dropout": 0.0,
45
+ "hidden_act": "gelu_pytorch_tanh",
46
+ "hidden_size": 1152,
47
+ "image_size": 896,
48
+ "intermediate_size": 4304,
49
+ "layer_norm_eps": 1e-06,
50
+ "model_type": "siglip_vision_model",
51
+ "num_attention_heads": 16,
52
+ "num_channels": 3,
53
+ "num_hidden_layers": 27,
54
+ "patch_size": 14,
55
+ "torch_dtype": "float32",
56
+ "vision_use_head": false
57
+ }
58
+ }
gemma_explicit_tokenizer.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Gemma Tokenizer for explicit Format
3
+
4
+ This tokenizer implements the explicit format for message processing:
5
+ Format: Uses the standard chat template with proper role labels (user/assistant)
6
+
7
+ The explicit format uses the model's built-in chat template and includes proper
8
+ loss computation flags for training.
9
+
10
+ To save:
11
+ uv run tokenizers/gemma_explicit_tokenizer.py
12
+ which will save the tokenizer to the repos/explicit-gemma-tokenizer directory.
13
+ mkdir repos/explicit12b
14
+ # copy model over
15
+ cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_explicit/checkpoint-8/* repos/explicit12b/
16
+ # copy tokenizer over
17
+ cp repos/explicit-gemma-tokenizer/* repos/explicit12b/
18
+ # upload to hf
19
+
20
+ uv run upload_to_hf.py \
21
+ --folder repos/explicit12b \
22
+ --repo-id tsor13/explicit12b
23
+ """
24
+
25
+ from typing import List, Dict, Any, Optional, Union
26
+ from transformers import AutoTokenizer
27
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
28
+ from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
29
+ import warnings
30
+ import difflib
31
+ import json
32
+ import os
33
+
34
+ # COULD INCLUDE IF WE WANT TO USE THE CHAT TEMPLATE, BUT REMOVING FOR NOW
35
+ # CUSTOM_CHAT_TEMPLATE = r"""
36
+ # {{ bos_token }}{{ '<start_of_turn>description\n' }}
37
+ # {%- if messages and messages[0]['role'] == 'system' -%}
38
+ # {%- if messages[0]['content'] is string -%}
39
+ # {{ messages[0]['content'] | trim }}
40
+ # {%- else -%}
41
+ # {{ messages[0]['content'][0]['text'] | trim }}
42
+ # {%- endif -%}
43
+ # {%- set loop_messages = messages[1:] -%}
44
+ # {%- else -%}
45
+ # You are a helpful assistant.
46
+ # {%- set loop_messages = messages -%}
47
+ # {%- endif -%}
48
+ # {{ '<end_of_turn>' }}
49
+ # {# ----- regular turns (input/output) ----- #}
50
+ # {%- for message in loop_messages -%}
51
+ # {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
52
+ # {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
53
+ # {%- endif -%}
54
+ # {%- if (message['role'] == 'assistant') -%}
55
+ # {%- set role = "output" -%}
56
+ # {%- elif (message['role'] == 'user') -%}
57
+ # {%- set role = "input" -%}
58
+ # {%- else -%}
59
+ # {%- set role = message['role'] -%}
60
+ # {%- endif -%}
61
+ # {{ '<start_of_turn>' + role + '\n' }}
62
+ # {%- if message['content'] is string -%}
63
+ # {{ message['content'] | trim }}
64
+ # {%- elif message['content'] is iterable -%}
65
+ # {%- for item in message['content'] -%}
66
+ # {%- if item['type'] == 'image' -%}
67
+ # {{ '<start_of_image>' }}
68
+ # {%- elif item['type'] == 'text' -%}
69
+ # {{ item['text'] | trim }}
70
+ # {%- endif -%}
71
+ # {%- endfor -%}
72
+ # {%- else -%}
73
+ # {{ raise_exception("Invalid content type") }}
74
+ # {%- endif -%}
75
+ # {{ '<end_of_turn>\n' }}
76
+ # {%- endfor -%}
77
+ # {%- if add_generation_prompt -%}
78
+ # {{ '<start_of_turn>output\n' }}
79
+ # {%- endif -%}
80
+ # """.strip("\n")
81
+
82
+
83
class GemmaExplicitTokenizer(GemmaTokenizerFast):
    """
    Custom tokenizer for Gemma models that implements explicit-format message
    processing.

    Messages use the roles ``description`` / ``input`` / ``output``.  Each turn
    is wrapped in ``<start_of_turn>ROLE\\n ... <end_of_turn>\\n`` markers, and
    training loss is computed only on ``output`` turns (see
    ``messages_to_loss_texts`` for the exact flags).

    Attributes:
        start_string (str): Marker that opens a turn (``<start_of_turn>``).
        end_string (str): Marker that closes a turn (``<end_of_turn>``).
    """

    def __init__(self, *args, **kwargs):
        """Initialize the tokenizer; accepts the same arguments as GemmaTokenizerFast."""
        super().__init__(*args, **kwargs)

        # Turn delimiters used by the explicit format.
        self.start_string = "<start_of_turn>"
        self.end_string = "<end_of_turn>"

        # Record the custom attributes in init_kwargs so they survive a
        # save_pretrained / from_pretrained round trip.
        if not hasattr(self, 'init_kwargs'):
            self.init_kwargs = {}
        self.init_kwargs['start_string'] = self.start_string
        self.init_kwargs['end_string'] = self.end_string

    @classmethod
    def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load a tokenizer from a pretrained model or path, returning an instance
        of this custom class instead of the base GemmaTokenizerFast.
        """
        # Load the base tokenizer first to get all configuration.
        base_tokenizer = GemmaTokenizerFast.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        # Create an instance of our class without re-running __init__, then
        # copy every attribute over from the base tokenizer.
        custom_tokenizer = cls.__new__(cls)
        for attr, value in base_tokenizer.__dict__.items():
            setattr(custom_tokenizer, attr, value)

        # Initialize the explicit-format attributes.
        custom_tokenizer.start_string = "<start_of_turn>"
        custom_tokenizer.end_string = "<end_of_turn>"

        if not hasattr(custom_tokenizer, 'init_kwargs'):
            custom_tokenizer.init_kwargs = {}
        custom_tokenizer.init_kwargs['start_string'] = custom_tokenizer.start_string
        custom_tokenizer.init_kwargs['end_string'] = custom_tokenizer.end_string

        return custom_tokenizer

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save the tokenizer to ``save_directory``, then patch
        ``tokenizer_config.json`` with the custom class name, the turn
        delimiters, and an ``auto_map`` entry so AutoTokenizer can load this
        class from the uploaded ``gemma_explicit_tokenizer.py``.
        """
        super().save_pretrained(save_directory, **kwargs)

        config_file = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, 'r') as f:
                config = json.load(f)
        else:
            config = {}

        config["tokenizer_class"] = "GemmaExplicitTokenizer"
        config["start_string"] = self.start_string
        config["end_string"] = self.end_string
        # auto_map lists [slow_class, fast_class]; only a fast tokenizer is
        # shipped, so both slots point at the same implementation.
        config["auto_map"] = {
            "AutoTokenizer": ["gemma_explicit_tokenizer.GemmaExplicitTokenizer", "gemma_explicit_tokenizer.GemmaExplicitTokenizer"]
        }

        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)

    def messages_to_loss_texts(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Convert messages (description / input / output) into text chunks
        (``text`` / ``compute_loss``) marking whether loss should be computed
        on each chunk during training.  Uses the explicit format matching
        chat_utils.py.

        Loss is computed only on output contents (plus their closing
        ``<end_of_turn>``).  If the conversation has no description, the first
        output is also excluded from the loss.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts; role is
                one of ``description``, ``input``, ``output``.
            start_generation: If True, append an open ``<start_of_turn>output``
                prompt at the end (for generation).

        Returns:
            List of ``{"text": str, "compute_loss": bool}`` dicts.

        Raises:
            ValueError: If a message has an unknown role.
        """

        # Gemma-3 explicit-format turn templates copied from chat_utils.py.
        def description_map(x):
            return [{
                "text": "<start_of_turn>description\n" + x + "<end_of_turn>\n",
                "compute_loss": False,
            }]

        def input_map(x):
            return [{
                "text": "<start_of_turn>input\n" + x + "<end_of_turn>\n",
                "compute_loss": False,
            }]

        def output_map(x):
            # Only the output content (and its <end_of_turn>) carries loss;
            # the role header and trailing newline do not.
            return [{
                "text": "<start_of_turn>output\n",
                "compute_loss": False,
            }, {
                "text": x + "<end_of_turn>",
                "compute_loss": True,
            }, {
                "text": "\n",
                "compute_loss": False,
            }]

        texts = []
        has_description = False
        first_output = True

        for message in messages:
            role = message["role"]
            content = message["content"]

            if role == "description":
                has_description = True
                texts.extend(description_map(content))
            elif role == "input":
                texts.extend(input_map(content))
            elif role == "output":
                out_texts = output_map(content)
                if first_output and not has_description:
                    # Without a description there is no context for the first
                    # output, so exclude all of it from the loss.
                    for text in out_texts:
                        text["compute_loss"] = False
                texts.extend(out_texts)
                first_output = False
            else:
                raise ValueError(f"Unknown role: {role}. Must be description, input, or output.")

        # Open a generation prompt if requested.
        if start_generation:
            texts.append({"text": "<start_of_turn>output\n", "compute_loss": False})

        return texts

    def messages_to_text(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
    ) -> str:
        """
        Render messages (description / input / output) to a single raw string.
        Uses the explicit format matching chat_utils.py.
        """
        texts = self.messages_to_loss_texts(messages, start_generation=start_generation)
        return "".join(item["text"] for item in texts)

    def tokenize_messages(
        self,
        messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
        start_generation: bool = False,
        **kwargs,
    ):
        """
        Tokenize messages to model inputs.  Supports a single conversation or
        a batch (list of conversations).  Good for generation.

        Extra ``kwargs`` are forwarded to ``self(...)``.
        """
        if isinstance(messages, list) and isinstance(messages[0], list):
            # Batch of conversations.
            all_texts = [self.messages_to_text(m, start_generation) for m in messages]
        else:
            # Single conversation; wrap it so tokenization is always batched.
            all_texts = [self.messages_to_text(messages, start_generation)]

        # Tokenize all texts in one batched call.
        processed = self(text=all_texts, **kwargs)

        # When starting generation the sequence must stay open, so strip a
        # trailing EOS token from each sequence if the tokenizer appended one.
        # Bug fix: the old code compared the batch's last *sequence* (a list)
        # to the EOS id — which could never match — and indexed a nonexistent
        # "labels" key.  Only applies to list output (not return_tensors).
        if start_generation and self.eos_token_id is not None and isinstance(processed["input_ids"], list):
            for i, ids in enumerate(processed["input_ids"]):
                if len(ids) and ids[-1] == self.eos_token_id:
                    processed["input_ids"][i] = ids[:-1]
                    processed["attention_mask"][i] = processed["attention_mask"][i][:-1]
        return processed

    def tokenize_loss_texts(
        self,
        texts: List[Dict[str, Any]],
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ):
        """
        Tokenize loss-annotated chunks (``text`` / ``compute_loss``) into
        ``input_ids`` / ``attention_mask`` / ``labels``.

        Chunks are tokenized one by one so each chunk's tokens can be labeled
        (token ids for loss-bearing chunks, -100 otherwise), then concatenated.
        Optional per-chunk ``example_ind`` / ``data_id`` fields are expanded
        token-wise into ``example_inds`` / ``data_ids``.

        Args:
            texts: List of chunk dicts, or a plain string (tokenized directly,
                no labels).
            loss_on_eos: Unsupported; must be False.
            include_eos: If False, the trailing EOS token is stripped from the
                combined output.

        Raises:
            ValueError: If ``loss_on_eos`` is requested, or an EOS token
                appears in the middle of a chunk.
        """
        if loss_on_eos:
            raise ValueError("Loss on EOS is not currently supported.")

        # Plain-string input: tokenize directly and just ensure a trailing EOS.
        if isinstance(texts, str):
            processed = self(text=texts)
            if (self.eos_token_id is not None and
                    processed["input_ids"][-1] != self.eos_token_id):
                processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
                processed["attention_mask"] = processed["attention_mask"] + [1]
            return processed

        all_processed = []
        all_texts = ''
        example_inds = []
        dataset_inds = []

        for i, item in enumerate(texts):
            processed = self(text=item["text"])

            # Keep the BOS token only on the very first chunk.  (The original
            # code stripped BOS in two separate places; the second pass was
            # always a no-op, so they are merged here.)
            if i != 0 and processed["input_ids"][0] == self.bos_token_id:
                processed["input_ids"] = processed["input_ids"][1:]
                processed["attention_mask"] = processed["attention_mask"][1:]

            # Drop a per-chunk trailing EOS; a single EOS is appended at the end.
            if processed["input_ids"][-1] == self.eos_token_id:
                processed["input_ids"] = processed["input_ids"][:-1]
                processed["attention_mask"] = processed["attention_mask"][:-1]

            # An EOS in the middle of a chunk is unsupported, except for
            # tokenizers whose EOS decodes to "<|im_end|>" (which may occur
            # literally in the text).
            if self.eos_token_id in processed["input_ids"]:
                if not self.decode([self.eos_token_id]) == "<|im_end|>":
                    raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.")

            # Label tokens with their ids when this chunk carries loss,
            # otherwise mask them with -100 (ignored by the loss).
            if item["compute_loss"]:
                processed["labels"] = processed["input_ids"].copy()
            else:
                processed["labels"] = [-100] * len(processed["input_ids"])

            all_processed.append(processed)
            all_texts += item["text"]

            # Expand optional per-chunk indices to one entry per token;
            # -1 marks "no index supplied".
            this_num = item.get("example_ind")
            if this_num is None:
                this_num = -1
            example_inds.extend([this_num] * len(processed["input_ids"]))

            dataset_ind = item.get("data_id")
            if dataset_ind is None:
                dataset_ind = -1
            dataset_inds.extend([dataset_ind] * len(processed["input_ids"]))

        # Concatenate all chunks into one example (copy the first chunk's
        # encoding to preserve any extra keys the tokenizer produced).
        processed = all_processed[0].copy()
        processed["input_ids"] = [tok for p in all_processed for tok in p["input_ids"]]
        processed["attention_mask"] = [tok for p in all_processed for tok in p["attention_mask"]]
        processed["labels"] = [tok for p in all_processed for tok in p["labels"]]
        processed["example_inds"] = example_inds
        processed["data_ids"] = dataset_inds

        # Sanity check: tokenizing the concatenated text in one shot should
        # yield the same length.  On mismatch, warn and print text/token diffs
        # to help debugging (the diff printing is gated on the mismatch so it
        # stays silent on the happy path).
        processed_all = self(text=all_texts)
        if len(processed_all["input_ids"]) != len(processed["input_ids"]):
            warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}")

            all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False)
            processed_text = self.decode(processed["input_ids"], skip_special_tokens=False)
            diff = difflib.unified_diff(all_text.splitlines(), processed_text.splitlines())
            print("Diff between texts:")
            print("\n".join(diff))

            all_tokens_str = '\n'.join(str(s) for s in processed_all["input_ids"])
            processed_tokens_str = '\n'.join(str(s) for s in processed["input_ids"])
            token_diff = difflib.unified_diff(all_tokens_str.splitlines(), processed_tokens_str.splitlines())
            print("Diff between tokenized texts:")
            print("\n".join(token_diff))

        # Ensure the combined sequence ends with exactly one EOS token, with
        # parallel -1 / masked entries for the bookkeeping lists.
        if (self.eos_token_id is not None and
                processed["input_ids"][-1] != self.eos_token_id):
            processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
            processed["example_inds"] = processed["example_inds"] + [-1]
            processed["attention_mask"] = processed["attention_mask"] + [1]
            if processed["labels"] is not None:
                # loss_on_eos is rejected above, so the EOS is always masked.
                processed["labels"] = processed["labels"] + [-100]
            if "data_ids" in processed:
                processed["data_ids"] = processed["data_ids"] + [-1]

        if not include_eos:
            # Strip the trailing EOS (and its parallel entries) if present.
            if processed["input_ids"][-1] == self.eos_token_id:
                processed["input_ids"] = processed["input_ids"][:-1]
                processed["attention_mask"] = processed["attention_mask"][:-1]
                processed["labels"] = processed["labels"][:-1]
                processed["example_inds"] = processed["example_inds"][:-1]
                processed["data_ids"] = processed["data_ids"][:-1]

        return processed
+
423
+
424
+
425
# Register the custom class with AutoTokenizer so it can be resolved by name.
# Only a fast tokenizer exists, hence slow_tokenizer_class=None.
AutoTokenizer.register("GemmaExplicitTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaExplicitTokenizer)

if __name__ == "__main__":
    # Example usage: build the custom tokenizer from the base Gemma checkpoint
    # (use from_gemma_pretrained for the first load, before a saved copy exists).
    custom_tokenizer = GemmaExplicitTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")

    # Test conversations in explicit role/content format, covering:
    # description + inputs/outputs, description + outputs only,
    # outputs only (no description), and alternating input/output turns.
    test_messages = [
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
            {"role": "output", "content": "10"},
            {"role": "input", "content": "What is 4+4?"},
        ],
    ]
    for messages in test_messages:
        # Convert messages to loss-annotated text chunks and show the flags.
        texts = custom_tokenizer.messages_to_loss_texts(messages)

        print("Texts with loss flags:")
        for i, text in enumerate(texts):
            print(f" {i}: {text}")
        processed = custom_tokenizer.tokenize_loss_texts(texts)
        print(f"\nProcessed:")
        print(str(processed["input_ids"][:10]) + "..." + str(processed["input_ids"][-10:]))

        # Render the same messages as one raw string with an open output turn.
        text = custom_tokenizer.messages_to_text(messages, start_generation=True)
        print(f"\nFull text with generation prompt:")
        print(text)
        # Tokenize the messages directly and print the resulting input_ids.
        processed = custom_tokenizer.tokenize_messages(messages, start_generation=True)
        print(f"\nProcessed:")
        print(str(processed["input_ids"][:10]) + "..." + str(processed["input_ids"][-10:]))

    # Test messages in standard chat (user/assistant) format via the
    # tokenizer's built-in chat template.
    test_messages = [
        [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
        ],
    ]
    chat_text = custom_tokenizer.apply_chat_template(test_messages, tokenize=False)[0]
    print(f"\nChat text:")
    print(chat_text)

    processed = custom_tokenizer.apply_chat_template(test_messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
    print(f"\nProcessed:")
    print(str(processed[:10]) + "..." + str(processed[-10:]))

    print("\nTesting save/load cycle:")
    # Save the tokenizer (including the patched tokenizer_config.json).
    tokenizer_path = "repos/explicit-gemma-tokenizer"
    custom_tokenizer.save_pretrained(tokenizer_path)
    print("Tokenizer saved successfully!")

    # Also copy this module into tokenizer_path so the auto_map entry written
    # by save_pretrained can resolve GemmaExplicitTokenizer after upload.
    import shutil
    shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_explicit_tokenizer.py"))
    print("GemmaExplicitTokenizer.py saved successfully!")
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 2,
3
+ "cache_implementation": "hybrid",
4
+ "do_sample": true,
5
+ "eos_token_id": [
6
+ 1,
7
+ 106
8
+ ],
9
+ "pad_token_id": 0,
10
+ "top_k": 64,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.51.3"
13
+ }
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d51fb9679731ca27310269fe69f1cb6299ef62522bf0f2bf26167b1b74344711
3
+ size 4979902192
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c78a4f6ed87c7f36f1f0ddffded6e5a5bd18d7325b0f50c0f508798b98bf0bca
3
+ size 4931296592
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd088a0ac47373275fcef0ed2810517b9526acc7461f78c25652ff7bc1c982e0
3
+ size 4931296656
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4720033e0908df6f7ce8caeebe39cf545eaf54f5160484f6c51e4b5d141f2e95
3
+ size 4931296656
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32813d3cc6c9eab28bff2dbd4567b8b0284cea30bccd562d49934d5470c6ccf1
3
+ size 4601000928
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_pan_and_scan": null,
5
+ "do_rescale": true,
6
+ "do_resize": true,
7
+ "image_mean": [
8
+ 0.5,
9
+ 0.5,
10
+ 0.5
11
+ ],
12
+ "image_processor_type": "Gemma3ImageProcessor",
13
+ "image_seq_length": 256,
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "pan_and_scan_max_num_crops": null,
20
+ "pan_and_scan_min_crop_size": null,
21
+ "pan_and_scan_min_ratio_to_activate": null,
22
+ "processor_class": "Gemma3Processor",
23
+ "resample": 2,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "height": 896,
27
+ "width": 896
28
+ }
29
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
3
+ size 33384568
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e18ebe474e27ae2b0da6c8f30ffd3da26bb1dcb4fb887f876a9f35764d257709
3
+ size 7377