tsor13 commited on
Commit
8b9b649
·
verified ·
1 Parent(s): cb2bfff

Initial upload of fine‑tuned Gemma + custom tokenizer

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,221 @@
1
+ ### tsor13/Special12b
2
+ The following is a model trained by [...suspense...] that is meant to:
3
+ - follow instructions better than pretrained models and be more diverse / less mode-collapsed than instruct models;
4
+ - be a really good, approximately bayesian in-context learner;
5
+ - fit a data generation process;
6
+ - be calibrated over distributions of possible outputs with respect to a population or epistemic uncertainty.
7
+ It is initialized from `google/gemma-3-12b-pt`.
8
+
9
+ This model/repo is a work in progress - expect updates.
10
+
11
+ Loading model example:
12
+ ```
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+ tokenizer = AutoTokenizer.from_pretrained("tsor13/special12b", trust_remote_code=True) # custom tokenizer for handling messages / loss
15
+ model = AutoModelForCausalLM.from_pretrained("tsor13/special12b", device_map="auto")
16
+ ```
17
+
18
+ It has its own chat-style input messages, with the following roles:
19
+ - `description` (optional): A description of the generating process, or some information meant to instantiate a prior;
20
+ - `input` (optional): Any variables that a model is not responsible for predicting, but could be used to condition generation somehow;
21
+ - `output`: This is what the model will actually predict / generate.
22
+
23
+ For example,
24
+ ```
25
+ messages = [
26
+ {"role": "description", "content": "Capitals"},
27
+ {"role": "input", "content": "France"},
28
+ {"role": "output", "content": "Paris"},
29
+ {"role": "input", "content": "Japan"},
30
+ ]
31
+ ```
32
+
33
+ To templatize the messages, you can use the tokenizer:
34
+ ```
35
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
36
+ print(formatted_prompt) # start_generation adds the <start_of_turn> token to condition the model for generation
37
+ ```
38
+ Output:
39
+ ```
40
+ Capitals
41
+ France
42
+ <start_of_turn>Paris<end_of_turn>
43
+ Japan
44
+ <start_of_turn>
45
+ ```
46
+ The data for the model to emulate / generate is wrapped in `<start_of_turn>` / `<end_of_turn>` tokens.
47
+ Description and input messages are not wrapped in anything. Thus, do not expect the model to generate these tokens; instead, focus on the wrapped output tokens.
48
+ Messages are separated by newlines.
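The templating rules above can be reproduced with a plain string function. This is an illustrative sketch (`render_messages` is a stand-in, not the tokenizer's actual `messages_to_text` implementation):

```python
def render_messages(messages, start_generation=False):
    """Render description/input/output messages into the model's text format."""
    lines = []
    for m in messages:
        if m["role"] == "output":
            # outputs are wrapped in the turn tokens
            lines.append(f"<start_of_turn>{m['content']}<end_of_turn>")
        else:
            # description / input are emitted verbatim
            lines.append(m["content"])
    if start_generation:
        # open a turn to condition the model for generation
        lines.append("<start_of_turn>")
    return "\n".join(lines)
```

Applied to the capitals example above, this reproduces the formatted prompt shown.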
49
+
50
+ In training, loss is ONLY calculated on the output tokens and the `<end_of_turn>` token. Thus, the model is only designed to generate / predict probabilities after `<start_of_turn>` and until `<end_of_turn>` - everything else is out of distribution for the model and not recommended.
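Conceptually, the loss masking can be sketched as follows. This is a simplified stand-in for the tokenizer's `messages_to_loss_texts` / `tokenize_loss_texts` helpers; the function names here are illustrative, not the actual implementation:

```python
IGNORE_INDEX = -100  # positions with this label are skipped by PyTorch cross-entropy

def loss_segments(messages, start_generation=False):
    """Split messages into (text, compute_loss) pairs."""
    segments = []
    for m in messages:
        if m["role"] == "output":
            segments.append(("<start_of_turn>", False))    # no loss on the opening token
            segments.append((m["content"], True))          # loss on the output text
            segments.append(("<end_of_turn>\n", True))     # loss on the closing token
        else:
            segments.append((m["content"] + "\n", False))  # no loss on description/input
    if start_generation:
        segments.append(("<start_of_turn>", False))
    return segments

def labels_for(token_ids, compute_loss):
    """Per-token labels for one segment: token ids where loss applies, else IGNORE_INDEX."""
    return list(token_ids) if compute_loss else [IGNORE_INDEX] * len(token_ids)
```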
51
+
52
+ Once you have the formatted text, you can tokenize as normal:
53
+ ```
54
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
55
+ ```
56
+
57
+ Let's look at what the model does. In this case there is a single correct answer, so we can inspect the model's probabilities after `<start_of_turn>`:
58
+ ```
59
+ import torch
60
+ with torch.no_grad():
61
+ output = model(**inputs)
62
+ logits = output.logits[0, -1, :]
63
+ probs = torch.nn.functional.softmax(logits, dim=-1)
64
+ top_probs, top_indices = torch.topk(probs, 10)
65
+ print("\nTop 10 probabilities for first output token:")
66
+ for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
67
+ token = tokenizer.decode(idx)
68
+ print(f"{i+1:2d}. '{token}' -> {prob.item():.4f}")
69
+ ```
70
+
71
+ Output:
72
+ ```
73
+ Top 10 probabilities for first output token:
74
+ 1. 'Tokyo' -> 0.9764
75
+ 2. 'Tok' -> 0.0070
76
+ 3. '東京' -> 0.0026
77
+ 4. 'Ky' -> 0.0019
78
+ 5. 'T' -> 0.0014
79
+ 6. ' Tokyo' -> 0.0014
80
+ 7. 'To' -> 0.0011
81
+ 8. 'Osaka' -> 0.0009
82
+ 9. 'Toy' -> 0.0007
83
+ 10. 'tok' -> 0.0005
84
+ ```
85
+
86
+ Great! Almost all of the probability mass is on the correct answer, Tokyo.
87
+
88
+ Let's try an example with many possible reasonable choices / a harder-to-describe distribution. For example, say that I'm interested in modeling "board games that I like". I may be hard-pressed to actually describe what it is that I like about games - but I could provide a few examples pretty easily.
89
+
90
+ ```
91
+ messages = [
92
+ {"role": "output", "content": "Dune: Imperium"},
93
+ {"role": "output", "content": "Acquire"},
94
+ {"role": "output", "content": "Catan"},
95
+ {"role": "output", "content": "Tigris and Euphrates"},
96
+ {"role": "output", "content": "Brass: Birmingham"},
97
+ ]
98
+ ```
99
+
100
+ Given these example outputs, the model will try to generate more outputs like them.
101
+ ```
102
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
103
+ n_gens = 4
104
+ inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
105
+
106
+ outputs = model.generate(**inputs, max_new_tokens=10, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
107
+ for i in range(n_gens):
108
+ print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
109
+ ```
110
+
111
+ Outputs:
112
+ ```
113
+ Catan: Rivals for Catan
114
+ Gloomhaven
115
+ Great Western Trail
116
+ Azul
117
+ ```
118
+ Not too bad!
119
+
120
+ You can also specify just the description:
121
+ Input:
122
+ ```
123
+ messages = [
124
+ {"role": "description", "content": "Descriptive colors"},
125
+ ]
126
+
127
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
128
+ n_gens = 4
129
+ inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
130
+
131
+ outputs = model.generate(**inputs, max_new_tokens=10, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
132
+ for i in range(n_gens):
133
+ print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
134
+ print()
135
+ ```
136
+ Output:
137
+ ```
138
+ Deep Sea Blue
139
+ Gray#222222
140
+ Gold, Red, Black
141
+ I can’t believe we’re already talking about color theory. How is this possible? Can time go any faster? Also how does your body
142
+ ```
143
+
144
+ By default, the model is trained to do one of two things: 1) emulate outputs if examples are provided, or 2) generate data based on the description. Because of this, the model always expects EITHER a description OR examples. If you want it to act slightly more like an instruction-following chat model, you can add a description such as the following:
145
+
146
+ ```
147
+ messages = [
148
+ {"role": "description", "content": "You are a helpful assistant who outputs the requested content."},
149
+ {"role": "input", "content": "A poem about a shark"},
150
+ ]
151
+ ```
152
+ To generate:
153
+ ```
154
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
155
+ n_gens = 4
156
+ inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
157
+
158
+ outputs = model.generate(**inputs, max_new_tokens=40, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
159
+ for i in range(n_gens):
160
+ print(f"Generation {i}:")
161
+ print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
162
+ ```
163
+
164
+ Some example generations:
165
+ ```
166
+ Generation 0:
167
+ A deep-sea creature, silent and fierce, Shivers through water, its body sleek. Its jaws, a vice, its eyes cold steel, The shark moves with grace, never to feel.
168
+ of power and danger,
169
+ Generation 1:
170
+ The great white shark lurks in the deep, with teeth so sharp, it could cut a whale in half. Its dorsal fin slices through the water, like a knife through butter, and its tail
171
+ Generation 2:
172
+ The shark swam in the sea, With a toothy grin, as if it could be glee. It was the top of the food chain, The apex of the sea's terrain. With sleek
173
+ Generation 3:
174
+ I am a gentle, tranquil wave, gliding smoothly across the ocean's expanse. Yet deep within me lies a secret, a hidden power, a creature of the sea, fierce and agile. It
175
+ ```
176
+
177
+
178
+
179
+
180
+ Finally, let's look at a synthetic data generation task. For example, maybe we want to generate situations to do social reasoning over, along with whether or not they are awkward. When there are multiple variables to condition on or generate, the model is accustomed to JSON format.
181
+
182
+ Input:
183
+ ```
184
+ import json
185
+ messages = [
186
+ {"role": "description", "content": "Situations to do social reasoning over, along with whether or not it is an awkward situation."},
187
+ {"role": "output", "content": json.dumps({
188
+ "situation": "You're at a party and you realize that your shirt is on backwards.",
189
+ "is_awkward": True,
190
+ })},
191
+ {"role": "output", "content": json.dumps({
192
+ "situation": "While at work, your boss commends you on a job well done.",
193
+ "is_awkward": False,
194
+ })},
195
+ {"role": "output", "content": json.dumps({
196
+ "situation": "Realizing you forgot to bring your passport to the airport.",
197
+ "is_awkward": True,
198
+ })},
199
+ ]
200
+
201
+ formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
202
+ n_gens = 4
203
+ inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
204
+
205
+ outputs = model.generate(**inputs, max_new_tokens=40, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
206
+ for i in range(n_gens):
207
+ print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
208
+ ```
209
+ Output:
210
+ ```
211
+ {"situation": "While walking on the street, someone waves and smiles at you, but you don't know them.", "is_awkward": false}
212
+ {"situation": "Taking a cab and giving the driver wrong directions.", "is_awkward": true}
213
+ {"situation": "Being told that an individual you've had a long-term crush on is also crushing on someone else.", "is_awkward": true}
214
+ {"situation": "Watching a loved one get proposed to.", "is_awkward": false}
215
+ ```
216
+
217
+ A few tips and tricks:
218
+ - Do not expect the model to do multi-turn chats. It is designed to be stateless and to treat each data point as "exchangeable" (roughly iid).
219
+ - If all you want is one reasonable answer, then a chat model is likely a better fit. However, if you want to generate many reasonable answers / diverse examples, this model is a better fit.
220
+ - The model is quite good at perspective taking / steering if you provide many examples.
221
+ - The model is reasonably good at expressing epistemic uncertainty over unsure outputs by sampling several times.
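The last tip can be sketched as: sample the model several times and treat the empirical frequencies as a rough probability estimate. The helper below is illustrative; the `samples` list stands in for repeated `model.generate` calls:

```python
from collections import Counter

def empirical_distribution(samples):
    """Frequencies of repeated generations as a rough probability estimate."""
    counts = Counter(samples)
    return {answer: n / len(samples) for answer, n in counts.items()}

# 20 hypothetical generations for an uncertain question
samples = ["Tokyo"] * 15 + ["Kyoto"] * 4 + ["Osaka"]
dist = empirical_distribution(samples)  # {"Tokyo": 0.75, "Kyoto": 0.2, "Osaka": 0.05}
```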
added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
config.json ADDED
@@ -0,0 +1,58 @@
1
+ {
2
+ "architectures": [
3
+ "Gemma3ForConditionalGeneration"
4
+ ],
5
+ "boi_token_index": 255999,
6
+ "eoi_token_index": 256000,
7
+ "image_token_index": 262144,
8
+ "initializer_range": 0.02,
9
+ "mm_tokens_per_image": 256,
10
+ "model_type": "gemma3",
11
+ "text_config": {
12
+ "attention_bias": false,
13
+ "attention_dropout": 0.0,
14
+ "attn_logit_softcapping": null,
15
+ "cache_implementation": "hybrid",
16
+ "final_logit_softcapping": null,
17
+ "head_dim": 256,
18
+ "hidden_activation": "gelu_pytorch_tanh",
19
+ "hidden_size": 3840,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 15360,
22
+ "max_position_embeddings": 131072,
23
+ "model_type": "gemma3_text",
24
+ "num_attention_heads": 16,
25
+ "num_hidden_layers": 48,
26
+ "num_key_value_heads": 8,
27
+ "query_pre_attn_scalar": 256,
28
+ "rms_norm_eps": 1e-06,
29
+ "rope_local_base_freq": 10000.0,
30
+ "rope_scaling": {
31
+ "factor": 8.0,
32
+ "rope_type": "linear"
33
+ },
34
+ "rope_theta": 1000000.0,
35
+ "sliding_window": 1024,
36
+ "sliding_window_pattern": 6,
37
+ "torch_dtype": "float32",
38
+ "use_cache": true,
39
+ "vocab_size": 262208
40
+ },
41
+ "torch_dtype": "bfloat16",
42
+ "transformers_version": "4.51.3",
43
+ "vision_config": {
44
+ "attention_dropout": 0.0,
45
+ "hidden_act": "gelu_pytorch_tanh",
46
+ "hidden_size": 1152,
47
+ "image_size": 896,
48
+ "intermediate_size": 4304,
49
+ "layer_norm_eps": 1e-06,
50
+ "model_type": "siglip_vision_model",
51
+ "num_attention_heads": 16,
52
+ "num_channels": 3,
53
+ "num_hidden_layers": 27,
54
+ "patch_size": 14,
55
+ "torch_dtype": "float32",
56
+ "vision_use_head": false
57
+ }
58
+ }
gemma_explicit_tokenizer.py ADDED
@@ -0,0 +1,374 @@
1
+ """
2
+ Custom Gemma Tokenizer for the Explicit Format
3
+
4
+ This tokenizer implements the explicit format for message processing:
5
+ Format: Uses the standard chat template with proper role labels (user/assistant)
6
+
7
+ The explicit format uses the model's built-in chat template and includes proper
8
+ loss computation flags for training.
9
+
10
+ To save:
11
+ uv run tokenizers/gemma_explicit_tokenizer.py
12
+ which will save the tokenizer to the repos/explicit-gemma-tokenizer directory.
13
+ mkdir repos/explicit12b
14
+ # copy model over
15
+ cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_explicit/checkpoint-8/* repos/explicit12b/
16
+ # copy tokenizer over
17
+ cp repos/explicit-gemma-tokenizer/* repos/explicit12b/
18
+ # upload to hf
19
+
20
+ uv run upload_to_hf.py \
21
+ --folder repos/explicit12b \
22
+ --repo-id tsor13/explicit12b
23
+ """
24
+
25
+ from typing import List, Dict, Any, Optional, Union
26
+ from transformers import AutoTokenizer
27
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
28
+ from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
29
+ import warnings
30
+ import difflib
31
+ import json
32
+ import os
33
+ import sys
34
+
35
+ # Add parent directory to path to import chat_utils
36
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
37
+ from chat_utils import chat_messages_to_text_loss, chat_messages_to_raw_text
38
+
39
+
40
+ class GemmaExplicitTokenizer(GemmaTokenizerFast):
41
+ """
42
+ Custom tokenizer for Gemma models that implements explicit format message processing.
43
+
44
+ This tokenizer formats messages using the explicit format where:
45
+ - Messages use the standard chat template with proper role labels
46
+ - Uses the model's built-in chat formatting
47
+ - Loss is computed on the assistant/output sections
48
+
49
+ Attributes:
50
+ start_string (str): The starting string used for output generation (depends on tokenizer)
51
+ end_string (str): The ending string used for output generation (depends on tokenizer)
52
+ """
53
+
54
+ def __init__(self, *args, **kwargs):
55
+ """
56
+ Initialize the custom tokenizer.
57
+
58
+ Accepts the same arguments as GemmaTokenizerFast.
59
+ """
60
+ super().__init__(*args, **kwargs)
61
+
62
+ # For explicit format, we use the tokenizer's own chat template
63
+ # The start/end strings will be determined by the chat template
64
+ self.start_string = None # Will be set dynamically
65
+ self.end_string = None # Will be set dynamically
66
+
67
+ # Add custom attributes to the tokenizer config for saving/loading
68
+ if not hasattr(self, 'init_kwargs'):
69
+ self.init_kwargs = {}
70
+ self.init_kwargs['start_string'] = self.start_string
71
+ self.init_kwargs['end_string'] = self.end_string
72
+
73
+ @classmethod
74
+ def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
75
+ """
76
+ Load a tokenizer from a pretrained model or path.
77
+
78
+ This method ensures our custom class is used instead of the base GemmaTokenizerFast.
79
+ """
80
+ # Load the base tokenizer first to get all configuration
81
+ base_tokenizer = GemmaTokenizerFast.from_pretrained(
82
+ pretrained_model_name_or_path, *args, **kwargs
83
+ )
84
+
85
+ # Create new instance of our custom class by copying the base tokenizer
86
+ custom_tokenizer = cls.__new__(cls)
87
+
88
+ # Copy all attributes from base tokenizer
89
+ for attr, value in base_tokenizer.__dict__.items():
90
+ setattr(custom_tokenizer, attr, value)
91
+
92
+ # Initialize our custom attributes for explicit format
93
+ custom_tokenizer.start_string = None # Will be determined dynamically
94
+ custom_tokenizer.end_string = None # Will be determined dynamically
95
+
96
+ # Update init_kwargs to include our custom attributes
97
+ if not hasattr(custom_tokenizer, 'init_kwargs'):
98
+ custom_tokenizer.init_kwargs = {}
99
+ custom_tokenizer.init_kwargs['start_string'] = custom_tokenizer.start_string
100
+ custom_tokenizer.init_kwargs['end_string'] = custom_tokenizer.end_string
101
+
102
+ return custom_tokenizer
103
+
104
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
105
+ """
106
+ Save the tokenizer to a directory, including custom configuration.
107
+ """
108
+ # Call parent save method
109
+ super().save_pretrained(save_directory, **kwargs)
110
+
111
+ # Save custom configuration
112
+ config_file = os.path.join(save_directory, "tokenizer_config.json")
113
+ if os.path.exists(config_file):
114
+ with open(config_file, 'r') as f:
115
+ config = json.load(f)
116
+ else:
117
+ config = {}
118
+
119
+ # Add our custom class info
120
+ config["tokenizer_class"] = "GemmaExplicitTokenizer"
121
+ config["start_string"] = self.start_string
122
+ config["end_string"] = self.end_string
123
+ # Point to our custom class in the uploaded file
124
+ config["auto_map"] = {
125
+ "AutoTokenizer": ["gemma_explicit_tokenizer.GemmaExplicitTokenizer", "gemma_explicit_tokenizer.GemmaExplicitTokenizer"]
126
+ }
127
+
128
+ with open(config_file, 'w') as f:
129
+ json.dump(config, f, indent=2)
130
+
131
+ def messages_to_loss_texts(
132
+ self,
133
+ messages: List[Dict[str, Any]],
134
+ loss_on_start_token: bool = False,
135
+ ) -> List[Dict[str, Any]]:
136
+ """
137
+ Convert messages (description / input / output) to texts (text / compute_loss), indicating whether loss should be calculated on each text during training.
138
+ Uses the explicit format from chat_utils.
139
+ """
140
+ return chat_messages_to_text_loss(messages, self, loss_on_start_token, start_gen_as="output")
141
+
142
+ def messages_to_text(
143
+ self,
144
+ messages: List[Dict[str, Any]],
145
+ start_generation: bool = False,
146
+ ) -> str:
147
+ """
148
+ Messages (description / input / output) to raw text (text).
149
+ Uses the explicit format from chat_utils.
150
+ """
151
+ return chat_messages_to_raw_text(messages, self, start_generation=start_generation, start_gen_as="output")
152
+
153
+
154
+ def tokenize_messages(
155
+ self,
156
+ messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
157
+ start_generation: bool = False,
158
+ **kwargs,
159
+ ):
160
+ """
161
+ For tokenizing from messages to texts. Supports batching. Good for generation
162
+ """
163
+ if isinstance(messages, list) and isinstance(messages[0], list):
164
+ # Handle list of lists of messages
165
+ all_texts = []
166
+ for message_list in messages:
167
+ texts = self.messages_to_text(message_list, start_generation)
168
+ all_texts.append(texts)
169
+ else:
170
+ # Handle single list of messages
171
+ texts = self.messages_to_text(messages, start_generation)
172
+ all_texts = [texts]
173
+
174
+ # Tokenize all texts
175
+ processed = self(text=all_texts, **kwargs)
176
+ return processed
177
+
178
+
179
+ def tokenize_loss_texts(
180
+ self,
181
+ texts: List[Dict[str, Any]],
182
+ loss_on_start_token: bool = False,
183
+ loss_on_eos: bool = False,
184
+ include_eos: bool = True,
185
+ ):
186
+ """
187
+ Tokenize texts (text / compute_loss) to tokenized texts (input_ids / attention_mask / labels).
188
+
189
+ Needs more complex logic to handle the back and forth labeling.
190
+ """
191
+ if loss_on_eos:
192
+ raise ValueError("Loss on EOS is not currently supported.")
193
+
194
+ # Handle single string input
195
+ if isinstance(texts, str):
196
+ processed = self(text=texts)
197
+ # Add EOS token if needed
198
+ if (self.eos_token_id is not None and
199
+ processed["input_ids"][-1] != self.eos_token_id):
200
+ processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
201
+ processed["attention_mask"] = processed["attention_mask"] + [1]
202
+ return processed
203
+
204
+ # Handle list of text dictionaries
205
+ all_processed = []
206
+ all_texts = ''
207
+ example_inds = []
208
+ dataset_inds = []
209
+
210
+ for i, item in enumerate(texts):
211
+ processed = self(text=item["text"])
212
+
213
+ # Remove BOS token from all but first item
214
+ if i != 0 and self.bos_token_id == processed["input_ids"][0]:
215
+ processed["input_ids"] = processed["input_ids"][1:]
216
+ processed["attention_mask"] = processed["attention_mask"][1:]
217
+
218
+ # Remove EOS token if present at the end
219
+ if processed["input_ids"][-1] == self.eos_token_id:
220
+ processed["input_ids"] = processed["input_ids"][:-1]
221
+ processed["attention_mask"] = processed["attention_mask"][:-1]
222
+
223
+ # Check for EOS token in the middle (with special handling for <|im_end|>)
224
+ if self.eos_token_id in processed["input_ids"]:
225
+ if not self.decode([self.eos_token_id]) == "<|im_end|>":
226
+ raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.")
227
+
228
+ # Set labels based on compute_loss flag
229
+ if item["compute_loss"]:
230
+ processed["labels"] = processed["input_ids"].copy()
231
+ else:
232
+ processed["labels"] = [-100] * len(processed["input_ids"])
233
+
234
+ # Remove duplicate BOS tokens
235
+ if all_processed:
236
+ if processed["input_ids"][0] == self.bos_token_id:
237
+ processed["input_ids"] = processed["input_ids"][1:]
238
+ processed["attention_mask"] = processed["attention_mask"][1:]
239
+ processed["labels"] = processed["labels"][1:]
240
+
241
+ all_processed.append(processed)
242
+ all_texts += item["text"]
243
+
244
+ # Handle example indices
245
+ this_num = -1
246
+ if 'example_ind' in item.keys():
247
+ if item["example_ind"] is not None:
248
+ this_num = item["example_ind"]
249
+ example_inds.extend([this_num] * len(processed["input_ids"]))
250
+
251
+ # Handle dataset indices
252
+ dataset_ind = -1
253
+ if "data_id" in item.keys():
254
+ if item["data_id"] is not None:
255
+ dataset_ind = item["data_id"]
256
+ dataset_inds.extend([dataset_ind] * len(processed["input_ids"]))
257
+
258
+ # Combine all processed results
259
+ processed = all_processed[0].copy()
260
+ processed["input_ids"] = [item for sublist in [p["input_ids"] for p in all_processed] for item in sublist]
261
+ processed["attention_mask"] = [item for sublist in [p["attention_mask"] for p in all_processed] for item in sublist]
262
+ processed["labels"] = [item for sublist in [p["labels"] for p in all_processed] for item in sublist]
263
+ processed["example_inds"] = example_inds
264
+ processed["data_ids"] = dataset_inds
265
+
266
+ # Validate by tokenizing all_texts at once and comparing
267
+ processed_all = self(text=all_texts)
268
+ if len(processed_all["input_ids"]) != len(processed["input_ids"]):
269
+ warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}")
270
+
271
+ # Generate diff for debugging
272
+ all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False)
273
+ processed_text = self.decode(processed["input_ids"], skip_special_tokens=False)
274
+
275
+ diff = difflib.unified_diff(all_text.splitlines(), processed_text.splitlines())
276
+ diff_str = "\n".join(diff)
277
+ print("Diff between texts:")
278
+ print(diff_str)
279
+
280
+ # Token diff
281
+ all_tokens_str = '\n'.join([str(s) for s in processed_all["input_ids"]])
282
+ processed_tokens_str = '\n'.join([str(s) for s in processed["input_ids"]])
283
+ token_diff = difflib.unified_diff(all_tokens_str.splitlines(), processed_tokens_str.splitlines())
284
+ token_diff_str = "\n".join(token_diff)
285
+ print("Diff between tokenized texts:")
286
+ print(token_diff_str)
287
+
288
+ # Add EOS token if needed
289
+ if (self.eos_token_id is not None and
290
+ processed["input_ids"][-1] != self.eos_token_id):
291
+ processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
292
+ processed["example_inds"] = processed["example_inds"] + [-1]
293
+ processed["attention_mask"] = processed["attention_mask"] + [1]
294
+ if processed["labels"] is not None:
295
+ if loss_on_eos:
296
+ processed["labels"] = processed["labels"] + [self.eos_token_id]
297
+ else:
298
+ processed["labels"] = processed["labels"] + [-100]
299
+ if "data_ids" in processed:
300
+ processed["data_ids"] = processed["data_ids"] + [-1]
301
+
302
+ if not include_eos:
303
+ # check if EOS token is present
304
+ if processed["input_ids"][-1] == self.eos_token_id:
305
+ # remove EOS token
306
+ processed["input_ids"] = processed["input_ids"][:-1]
307
+ processed["attention_mask"] = processed["attention_mask"][:-1]
308
+ processed["labels"] = processed["labels"][:-1]
309
+ processed["example_inds"] = processed["example_inds"][:-1]
310
+ processed["data_ids"] = processed["data_ids"][:-1]
311
+
312
+ return processed
313
+
314
+ def tokenize_messages(
315
+ self,
316
+ messages: List[Dict[str, Any]],
317
+ loss_on_start_token: bool = False,
318
+ loss_on_eos: bool = False,
319
+ include_eos: bool = True,
320
+ ) -> Dict[str, Any]:
321
+ """
322
+ Tokenize messages into tokenized texts with loss labels applied.
323
+ """
324
+ # First convert messages to text with loss computation flags
325
+ texts = self.messages_to_loss_texts(messages, loss_on_start_token)
326
+
327
+ # Then tokenize the texts
328
+ return self.tokenize_loss_texts(texts, loss_on_eos, include_eos=include_eos)
329
+
330
+
331
+
332
+
333
+ # Register tokenizer classes for AutoTokenizer
334
+ AutoTokenizer.register("GemmaExplicitTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaExplicitTokenizer)
335
+
336
+
337
+ if __name__ == "__main__":
338
+ # Example usage
339
+ # for first load
340
+ custom_tokenizer = GemmaExplicitTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")
341
+
342
+ # for subsequent loads
343
+ # custom_tokenizer = GemmaExplicitTokenizer.from_pretrained("tsor13/explicit-gemma-12b-pt")
344
+ # custom_tokenizer = GemmaExplicitTokenizer.from_pretrained("repos/explicit-gemma-12b-pt")
345
+
346
+ # Test messages in role/content format
347
+ messages = [
348
+ {"role": "description", "content": "This is a test task"},
349
+ {"role": "input", "content": "What is 2+2?"},
350
+ {"role": "output", "content": "4"},
351
+ {"role": "input", "content": "What is 3+3?"},
352
+ # {"role": "output", "content": "6"}
353
+ ]
354
+
355
+ # get messages to text_loss
356
+ texts = custom_tokenizer.messages_to_loss_texts(messages)
357
+ print("Texts with loss flags:")
358
+ for i, text in enumerate(texts):
359
+ print(f" {i}: {text}")
360
+
361
+ text = custom_tokenizer.messages_to_text(messages, start_generation=True)
362
+ print(f"\nFull text with generation prompt:")
363
+ print(text)
364
+
365
+ print("\nTesting save/load cycle:")
366
+ # Test saving and loading
367
+ tokenizer_path = "repos/explicit-gemma-tokenizer"
368
+ custom_tokenizer.save_pretrained(tokenizer_path)
369
+ print("Tokenizer saved successfully!")
370
+
371
+ # also save this file in the tokenizer_path
372
+ import shutil
373
+ shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_explicit_tokenizer.py"))
374
+ print("GemmaExplicitTokenizer.py saved successfully!")
generation_config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "bos_token_id": 2,
3
+ "cache_implementation": "hybrid",
4
+ "do_sample": true,
5
+ "eos_token_id": [
6
+ 1,
7
+ 106
8
+ ],
9
+ "pad_token_id": 0,
10
+ "top_k": 64,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.51.3"
13
+ }
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:542b2874d19bff3762d350e9c1bd370e50d80c5f709b80f1c7a443fc01849b20
3
+ size 4979902192
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4ffcce1c51995b5dc648d237497030a2f779a4cd33761be3737393254886079
3
+ size 4931296592
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:092ae6527089321ed068b3103a101b768b3953bd4c43140bd82bdc423d0795cb
3
+ size 4931296656
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:422970ba53b0cb3fc42ddd8dc27077b25ee953c4d64df6861b35ee92a4529d6d
3
+ size 4931296656
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71868eb31768e9d0d97fe9207f4abb09a9332e565552cfc762a5b9d0dd1b905f
3
+ size 4601000928
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
3
+ size 33384568
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4b89185f26b9cbdf0d6a835b8d4fb2bf8c6584347553dcaa366582ca993a82c
3
+ size 7377