tsor13 commited on
Commit
e0c9819
·
verified ·
1 Parent(s): c20ba7e

Initial upload of fine-tuned Gemma + custom tokenizer

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
config.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Gemma3ForConditionalGeneration"
4
+ ],
5
+ "boi_token_index": 255999,
6
+ "eoi_token_index": 256000,
7
+ "eos_token_id": [
8
+ 1,
9
+ 106
10
+ ],
11
+ "image_token_index": 262144,
12
+ "initializer_range": 0.02,
13
+ "mm_tokens_per_image": 256,
14
+ "model_type": "gemma3",
15
+ "text_config": {
16
+ "attention_bias": false,
17
+ "attention_dropout": 0.0,
18
+ "attn_logit_softcapping": null,
19
+ "cache_implementation": "hybrid",
20
+ "final_logit_softcapping": null,
21
+ "head_dim": 256,
22
+ "hidden_activation": "gelu_pytorch_tanh",
23
+ "hidden_size": 3840,
24
+ "initializer_range": 0.02,
25
+ "intermediate_size": 15360,
26
+ "max_position_embeddings": 131072,
27
+ "model_type": "gemma3_text",
28
+ "num_attention_heads": 16,
29
+ "num_hidden_layers": 48,
30
+ "num_key_value_heads": 8,
31
+ "query_pre_attn_scalar": 256,
32
+ "rms_norm_eps": 1e-06,
33
+ "rope_local_base_freq": 10000.0,
34
+ "rope_scaling": {
35
+ "factor": 8.0,
36
+ "rope_type": "linear"
37
+ },
38
+ "rope_theta": 1000000.0,
39
+ "sliding_window": 1024,
40
+ "sliding_window_pattern": 6,
41
+ "torch_dtype": "float32",
42
+ "use_cache": true,
43
+ "vocab_size": 262208
44
+ },
45
+ "torch_dtype": "bfloat16",
46
+ "transformers_version": "4.51.3",
47
+ "vision_config": {
48
+ "attention_dropout": 0.0,
49
+ "hidden_act": "gelu_pytorch_tanh",
50
+ "hidden_size": 1152,
51
+ "image_size": 896,
52
+ "intermediate_size": 4304,
53
+ "layer_norm_eps": 1e-06,
54
+ "model_type": "siglip_vision_model",
55
+ "num_attention_heads": 16,
56
+ "num_channels": 3,
57
+ "num_hidden_layers": 27,
58
+ "patch_size": 14,
59
+ "torch_dtype": "float32",
60
+ "vision_use_head": false
61
+ }
62
+ }
gemma_chat_tokenizer.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Gemma Tokenizer for chat format
3
+
4
+ This tokenizer implements the chat format for message processing:
5
+ Format: Uses the standard chat template with proper role labels (user/assistant)
6
+
7
+ The chat format uses the model's built-in chat template and includes proper
8
+ loss computation flags for training with "assistant" as the generation role.
9
+
10
+ To save:
11
+ uv run tokenizers/gemma_chat_tokenizer.py
12
+ which will save the tokenizer to the repos/chat-gemma-tokenizer directory.
13
+ mkdir repos/chat12b
14
+ # copy model over
15
+ cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_chat/checkpoint-8/* repos/chat12b/
16
+ # copy tokenizer over
17
+ cp repos/chat-gemma-tokenizer/* repos/chat12b/
18
+ # upload to hf
19
+
20
+ uv run upload_to_hf.py \
21
+ --folder repos/chat12b \
22
+ --repo-id tsor13/chat12b
23
+ """
24
+
25
+ from typing import List, Dict, Any, Optional, Union
26
+ from transformers import AutoTokenizer
27
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
28
+ from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
29
+ import warnings
30
+ import difflib
31
+ import json
32
+ import os
33
+ import sys
34
+
35
+
36
class GemmaChatTokenizer(GemmaTokenizerFast):
    """
    Custom tokenizer for Gemma models that implements chat-format message processing.

    This tokenizer converts task-style messages (roles: description / input / output)
    into standard chat messages (roles: system / user / assistant), renders them with
    the model's built-in chat template, and attaches per-token loss flags so that loss
    is computed only on assistant turns.

    Review fixes applied:
    - The generation-oriented ``tokenize_messages`` was silently shadowed by a later
      method of the same name; it is now exposed as ``tokenize_messages_for_generation``.
    - ``breakpoint()`` removed from library code (replaced with a warning).
    - ``tokenize_messages`` now passes ``loss_on_eos`` by keyword; previously it was
      passed positionally into the unused ``loss_on_start_token`` slot, so
      ``loss_on_eos=True`` was silently ignored instead of raising as documented.

    Attributes:
        start_string (str): Marker opening a turn ("<start_of_turn>").
        end_string (str): Marker closing a turn ("<end_of_turn>").
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the custom tokenizer.

        Accepts the same arguments as GemmaTokenizerFast.
        """
        super().__init__(*args, **kwargs)

        # For chat format, we use the tokenizer's own chat template.
        # These markers match Gemma's template; see messages_to_loss_texts.
        self.start_string = "<start_of_turn>"  # Will be set dynamically
        self.end_string = "<end_of_turn>"  # Will be set dynamically

        # Record custom attributes in init_kwargs so save_pretrained round-trips them.
        if not hasattr(self, 'init_kwargs'):
            self.init_kwargs = {}
        self.init_kwargs['start_string'] = self.start_string
        self.init_kwargs['end_string'] = self.end_string

    @classmethod
    def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load a tokenizer from a pretrained model or path.

        This method ensures our custom class is used instead of the base
        GemmaTokenizerFast: it loads the base tokenizer, then rebinds its state
        onto a new instance of this class (``cls.__new__`` skips ``__init__``
        so the base tokenizer's fully-initialized state is not clobbered).
        """
        # Load the base tokenizer first to get all configuration.
        base_tokenizer = GemmaTokenizerFast.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        # Create a new instance of our custom class by copying the base tokenizer.
        custom_tokenizer = cls.__new__(cls)
        for attr, value in base_tokenizer.__dict__.items():
            setattr(custom_tokenizer, attr, value)

        # Initialize our custom attributes for chat format.
        custom_tokenizer.start_string = "<start_of_turn>"
        custom_tokenizer.end_string = "<end_of_turn>"

        # Update init_kwargs to include our custom attributes.
        if not hasattr(custom_tokenizer, 'init_kwargs'):
            custom_tokenizer.init_kwargs = {}
        custom_tokenizer.init_kwargs['start_string'] = custom_tokenizer.start_string
        custom_tokenizer.init_kwargs['end_string'] = custom_tokenizer.end_string

        return custom_tokenizer

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save the tokenizer to a directory, including custom configuration.

        After the base save, rewrites tokenizer_config.json to record this
        class name, the turn markers, and an auto_map entry pointing at the
        uploaded gemma_chat_tokenizer.py so AutoTokenizer can load it remotely.
        """
        # Call parent save method.
        super().save_pretrained(save_directory, **kwargs)

        # Merge our custom configuration into the saved tokenizer_config.json.
        config_file = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(config_file):
            with open(config_file, 'r') as f:
                config = json.load(f)
        else:
            config = {}

        # Add our custom class info.
        config["tokenizer_class"] = "GemmaChatTokenizer"
        config["start_string"] = self.start_string
        config["end_string"] = self.end_string
        # Point to our custom class in the uploaded file (slow and fast slots).
        config["auto_map"] = {
            "AutoTokenizer": ["gemma_chat_tokenizer.GemmaChatTokenizer", "gemma_chat_tokenizer.GemmaChatTokenizer"]
        }

        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)

    def messages_to_chat_messages(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
        default_user_message: str = "Generate.",
    ) -> List[Dict[str, Any]]:
        """
        Convert messages (description / input / output) to chat messages (role / content).

        Uses the same logic as chat_utils.py with system messages:
        - "description" becomes (part of) the system message;
        - leading "input"/"output" pairs before any description are folded into the
          system message as examples;
        - later inputs become user turns and outputs become assistant turns.
        """
        chat_prompt = """You are tasked with generating outputs from a particular, potentially stochastic, generative process. You will be given some combination of:
- Description: A natural description of the generative process / data distribution
- Input: An input on which to condition the generative process.
- Example outputs: Example outputs from the process, either in a user message or as prior generations from a chat message. You may assume that any given outputs are exchangeable with one another (order-invariant) and generated from the same process (roughly i.i.d.). If the output data pertains to a single object, it just contains the output. If it contains multiple objects, use json formatting with keys for the name of the output variable.
You will be provided at least either a description or an example output.

Given these components, your job is to generate JUST the output in your response, roughly approximating the underlying generative process, maintaining any underlying stochasticity (if any is present). If you are asked to generate again, you will either be given an additional input to condition on, or will just be told to "Generate"."""

        chat_messages = []
        system_message = chat_prompt
        has_description_or_output = False
        has_input = False

        for message in messages:
            if message["role"] == "description":
                system_message += "\n\nDescription: " + message["content"]
                chat_messages.append({"role": "system", "content": system_message})
                has_description_or_output = True
            elif message["role"] == "input":
                has_input = True
                if not has_description_or_output:
                    # No system message emitted yet: fold the input into it as an example.
                    system_message += "\n\nExample Input: " + message["content"]
                else:
                    chat_messages.append({"role": "user", "content": message["content"]})
            elif message["role"] == "output":
                if not has_description_or_output:
                    # First output with no description: fold into the system message
                    # and emit the (now complete) system turn.
                    system_message += "\n\nExample Output: " + message["content"]
                    chat_messages.append({"role": "system", "content": system_message})
                    has_description_or_output = True
                else:
                    if not has_input:
                        # Chat templates expect user/assistant alternation; insert
                        # a default user turn when no input preceded this output.
                        chat_messages.append({"role": "user", "content": default_user_message})
                    chat_messages.append({"role": "assistant", "content": message["content"]})

        if len(chat_messages) == 0:
            # add system message
            chat_messages.append({"role": "system", "content": system_message})
            # also add in empty user message for now for gemma
            chat_messages.append({"role": "user", "content": ""})
        if len(chat_messages) == 1:
            # add in empty user message for now for gemma
            chat_messages.append({"role": "user", "content": ""})

        # if last message is output and start_generation is true, add a default user message
        if start_generation and chat_messages[-1]["role"] == "assistant":
            chat_messages.append({"role": "user", "content": default_user_message})

        return chat_messages

    def messages_to_loss_texts(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        default_user_message: str = "Generate.",
        start_generation: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Convert messages (description / input / output) to a list of
        {"text": str, "compute_loss": bool} segments for training.

        The rendered chat template is split at every "<start_of_turn>model\\n"
        marker; the model's turn (through "<end_of_turn>") gets compute_loss=True,
        everything else compute_loss=False.
        """
        # FOR NOW, OVERRIDING TO FALSE
        loss_on_start_token = False

        texts = []

        chat_messages = self.messages_to_chat_messages(messages, start_generation=start_generation, default_user_message=default_user_message)

        # Apply chat template.
        full_text = self.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=start_generation)
        # Strip <bos>; tokenization adds it back, so keeping it here would duplicate it.
        full_text = full_text.replace("<bos>", "")

        text_to_split = full_text
        # Find all places starting with the model-turn marker.
        model_start_text = "<start_of_turn>model\n"  # TODO - manual for now, change later
        while model_start_text in text_to_split:
            # Everything up to and including the marker carries no loss.
            model_start_loc = text_to_split.find(model_start_text)
            split_ind = model_start_loc + len(model_start_text)
            text_to_add, text_to_split = text_to_split[:split_ind], text_to_split[split_ind:]
            texts.append({"text": text_to_add, "compute_loss": False})
            # The model's generation through end_string carries loss.
            end_string_loc = text_to_split.find(self.end_string)
            end_ind = end_string_loc + len(self.end_string)
            text_to_add, text_to_split = text_to_split[:end_ind], text_to_split[end_ind:]
            # Calculate loss on ALL assistant messages.
            texts.append({"text": text_to_add, "compute_loss": True})
        if len(text_to_split) > 0:
            texts.append({"text": text_to_split, "compute_loss": False})
        if len(texts) == 0:
            # Fix: was breakpoint(), which drops into pdb in production runs.
            warnings.warn(f"messages_to_loss_texts produced no segments for messages: {messages}")

        return texts

    def messages_to_text(
        self,
        messages: List[Dict[str, Any]],
        start_generation: bool = False,
    ) -> str:
        """
        Render messages (description / input / output) to raw text by joining
        the loss segments from messages_to_loss_texts.
        """
        texts = self.messages_to_loss_texts(messages, start_generation=start_generation)
        text = "".join([text["text"] for text in texts])
        return text

    def tokenize_messages_for_generation(
        self,
        messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
        start_generation: bool = False,
        **kwargs,
    ):
        """
        Tokenize messages into model inputs. Supports batching. Good for generation.

        NOTE(review): this was previously named ``tokenize_messages`` and was
        unreachable — a later method of the same name shadowed it. Renamed so the
        generation path is actually callable; no existing caller could have used
        the old name for this behavior.
        """
        if isinstance(messages, list) and isinstance(messages[0], list):
            # Handle list of lists of messages (a batch).
            all_texts = []
            for message_list in messages:
                texts = self.messages_to_text(message_list, start_generation)
                all_texts.append(texts)
        else:
            # Handle single list of messages.
            texts = self.messages_to_text(messages, start_generation)
            all_texts = [texts]

        # Tokenize all texts.
        processed = self(text=all_texts, **kwargs)
        return processed

    def tokenize_loss_texts(
        self,
        texts: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ):
        """
        Tokenize segments ({"text", "compute_loss"}) into input_ids /
        attention_mask / labels, masking labels (-100) where compute_loss is False.

        Also threads optional per-segment "example_ind" / "data_id" metadata into
        parallel per-token lists ("example_inds" / "data_ids", -1 where absent).

        Raises:
            ValueError: if loss_on_eos is requested (unsupported), or if an EOS
                token appears mid-sequence (except the "<|im_end|>" special case).
        """
        if loss_on_eos:
            raise ValueError("Loss on EOS is not currently supported.")

        # Handle single string input.
        if isinstance(texts, str):
            processed = self(text=texts)
            # Add EOS token if needed.
            if (self.eos_token_id is not None and
                processed["input_ids"][-1] != self.eos_token_id):
                processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
                processed["attention_mask"] = processed["attention_mask"] + [1]
            return processed

        # Handle list of text dictionaries.
        all_processed = []
        all_texts = ''
        example_inds = []
        dataset_inds = []

        for i, item in enumerate(texts):
            processed = self(text=item["text"])

            # Remove BOS token from all but the first item so the concatenation
            # has a single leading BOS.
            if i != 0 and self.bos_token_id == processed["input_ids"][0]:
                processed["input_ids"] = processed["input_ids"][1:]
                processed["attention_mask"] = processed["attention_mask"][1:]

            # Remove EOS token if present at the end (re-appended once at the end).
            if processed["input_ids"][-1] == self.eos_token_id:
                processed["input_ids"] = processed["input_ids"][:-1]
                processed["attention_mask"] = processed["attention_mask"][:-1]

            # Check for EOS token in the middle (with special handling for <|im_end|>).
            if self.eos_token_id in processed["input_ids"]:
                if not self.decode([self.eos_token_id]) == "<|im_end|>":
                    raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.")

            # Set labels based on compute_loss flag.
            if item["compute_loss"]:
                processed["labels"] = processed["input_ids"].copy()
            else:
                processed["labels"] = [-100] * len(processed["input_ids"])

            # Remove duplicate BOS tokens (defensive second pass with labels in sync).
            if all_processed:
                if processed["input_ids"][0] == self.bos_token_id:
                    processed["input_ids"] = processed["input_ids"][1:]
                    processed["attention_mask"] = processed["attention_mask"][1:]
                    processed["labels"] = processed["labels"][1:]

            all_processed.append(processed)
            all_texts += item["text"]

            # Handle example indices (-1 when absent).
            this_num = -1
            if 'example_ind' in item:
                if item["example_ind"] is not None:
                    this_num = item["example_ind"]
            example_inds.extend([this_num] * len(processed["input_ids"]))

            # Handle dataset indices (-1 when absent).
            dataset_ind = -1
            if "data_id" in item:
                if item["data_id"] is not None:
                    dataset_ind = item["data_id"]
            dataset_inds.extend([dataset_ind] * len(processed["input_ids"]))

        # Combine all processed results.
        processed = all_processed[0].copy()
        processed["input_ids"] = [item for sublist in [p["input_ids"] for p in all_processed] for item in sublist]
        processed["attention_mask"] = [item for sublist in [p["attention_mask"] for p in all_processed] for item in sublist]
        processed["labels"] = [item for sublist in [p["labels"] for p in all_processed] for item in sublist]
        processed["example_inds"] = example_inds
        processed["data_ids"] = dataset_inds

        # Validate by tokenizing all_texts at once and comparing. Per-segment
        # tokenization can differ from whole-text tokenization at segment
        # boundaries (merge effects), hence the check.
        processed_all = self(text=all_texts)
        if len(processed_all["input_ids"]) != len(processed["input_ids"]):
            warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}")

            # Generate diff for debugging (only on mismatch, to avoid spamming
            # every call).
            all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False)
            processed_text = self.decode(processed["input_ids"], skip_special_tokens=False)

            diff = difflib.unified_diff(all_text.splitlines(), processed_text.splitlines())
            diff_str = "\n".join(diff)
            print("Diff between texts:")
            print(diff_str)

            # Token diff.
            all_tokens_str = '\n'.join([str(s) for s in processed_all["input_ids"]])
            processed_tokens_str = '\n'.join([str(s) for s in processed["input_ids"]])
            token_diff = difflib.unified_diff(all_tokens_str.splitlines(), processed_tokens_str.splitlines())
            token_diff_str = "\n".join(token_diff)
            print("Diff between tokenized texts:")
            print(token_diff_str)

        # Add EOS token if needed, keeping all parallel lists in sync.
        if (self.eos_token_id is not None and
            processed["input_ids"][-1] != self.eos_token_id):
            processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
            processed["example_inds"] = processed["example_inds"] + [-1]
            processed["attention_mask"] = processed["attention_mask"] + [1]
            if processed["labels"] is not None:
                if loss_on_eos:
                    processed["labels"] = processed["labels"] + [self.eos_token_id]
                else:
                    processed["labels"] = processed["labels"] + [-100]
            if "data_ids" in processed:
                processed["data_ids"] = processed["data_ids"] + [-1]

        if not include_eos:
            # check if EOS token is present
            if processed["input_ids"][-1] == self.eos_token_id:
                # remove EOS token
                processed["input_ids"] = processed["input_ids"][:-1]
                processed["attention_mask"] = processed["attention_mask"][:-1]
                processed["labels"] = processed["labels"][:-1]
                processed["example_inds"] = processed["example_inds"][:-1]
                processed["data_ids"] = processed["data_ids"][:-1]

        return processed

    def tokenize_messages(
        self,
        messages: List[Dict[str, Any]],
        loss_on_start_token: bool = False,
        loss_on_eos: bool = False,
        include_eos: bool = True,
    ) -> Dict[str, Any]:
        """
        Tokenize messages into tokenized texts with loss labels applied (training path).
        """
        # First convert messages to text with loss computation flags.
        texts = self.messages_to_loss_texts(messages, loss_on_start_token)

        # Then tokenize the texts.
        # Fix: loss_on_eos was previously passed positionally into the
        # loss_on_start_token slot, so loss_on_eos=True was silently ignored.
        return self.tokenize_loss_texts(texts, loss_on_eos=loss_on_eos, include_eos=include_eos)
426
+
427
+
428
+
429
+
430
# Register tokenizer classes for AutoTokenizer so AutoTokenizer.from_pretrained
# can resolve "GemmaChatTokenizer" (as written into tokenizer_config.json by
# save_pretrained above).
# NOTE(review): transformers' AutoTokenizer.register normally takes a config
# class as its first argument, not a string — confirm this registration
# actually takes effect on the transformers version in use (4.51.3 per config).
AutoTokenizer.register("GemmaChatTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaChatTokenizer)
432
+
433
+
434
if __name__ == "__main__":
    # Example usage / smoke test. Downloads the base tokenizer from the Hub,
    # exercises the message-to-text conversion on several fixtures, then saves
    # the tokenizer (plus this source file) for upload.
    # for first load
    custom_tokenizer = GemmaChatTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")

    # for subsequent loads
    # custom_tokenizer = GemmaChatTokenizer.from_pretrained("tsor13/chat-gemma-12b-pt")
    # custom_tokenizer = GemmaChatTokenizer.from_pretrained("repos/chat-gemma-12b-pt")

    # Test messages in role/content format. Fixtures cover: description only;
    # description + input/output rounds; description + outputs only; outputs
    # only; inputs/outputs with no description; and multi-round exchanges.
    test_messages = [
        [
            {"role": "description", "content": "Pick a number between 1 and 100"},
        ],

        [
            {"role": "description", "content": "This is a test task"},
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
            {"role": "output", "content": "10"},
            {"role": "input", "content": "What is 4+4?"},
        ],
        [
            {"role": "description", "content": "DESCRIPTION"},
            {"role": "input", "content": "INPUT1"},
            {"role": "output", "content": "OUTPUT1"},
            {"role": "input", "content": "INPUT2"},
            {"role": "output", "content": "OUTPUT2"},
        ],
        [
            {"role": "description", "content": "DESCRIPTION"},
            {"role": "output", "content": "OUTPUT1"},
            {"role": "output", "content": "OUTPUT2"},
        ],
    ]
    for messages in test_messages:
        # get messages to text_loss
        texts = custom_tokenizer.messages_to_loss_texts(messages)

        print("Texts with loss flags:")
        for i, text in enumerate(texts):
            print(f" {i}: {text}")

        text = custom_tokenizer.messages_to_text(messages, start_generation=True)
        print(f"\nFull text with generation prompt:")
        print(text)


    print("\nTesting save/load cycle:")
    # Test saving and loading
    tokenizer_path = "repos/chat-gemma-tokenizer"
    custom_tokenizer.save_pretrained(tokenizer_path)
    print("Tokenizer saved successfully!")

    # also save this file in the tokenizer_path so the auto_map entry written
    # by save_pretrained can resolve gemma_chat_tokenizer.GemmaChatTokenizer
    import shutil
    shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_chat_tokenizer.py"))
    print("GemmaChatTokenizer.py saved successfully!")
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 2,
3
+ "cache_implementation": "hybrid",
4
+ "do_sample": true,
5
+ "eos_token_id": [
6
+ 1,
7
+ 106
8
+ ],
9
+ "pad_token_id": 0,
10
+ "top_k": 64,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.51.3"
13
+ }
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:767d7a9fea55e6c77a497deecbe9a96f956f1e97b0e48213f4978b94e4043eb2
3
+ size 4979902192
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f2b5baaa7eea2abffcf0e846e58d4b84c10c6a15cdbbd11a67072e4522a49f1
3
+ size 4931296592
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04e9e018d3d0757eb516c9acc1e5ff203a944ecb1a24d9f0d6e3e08a4af75b11
3
+ size 4931296656
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:702c672a6a1951b6fc81dedb4276ea8e6f146abb0d041aaeec6a0a042c8beb3b
3
+ size 4931296656
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e00779aa228c203f949575ec7aba50f81deec0ea80c4ecbf0be8c226c89c15f1
3
+ size 4601000928
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
3
+ size 33384568
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:267c12b1809a697fee02ca2efc4f0dd076c1f378edc40b9163dacfb3f1028db9
3
+ size 7313