Initial upload of fine‑tuned Gemma + custom tokenizer
Browse files- .gitattributes +1 -0
- README.md +221 -0
- added_tokens.json +3 -0
- config.json +62 -0
- gemma_chat_tokenizer.py +448 -0
- generation_config.json +13 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- special_tokens_map.json +33 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
- training_args.bin +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### tsor13/Special12b
|
| 2 |
+
The following is a model trained by [...suspense...] that is meant to:
|
| 3 |
+
- follow instructions better than pretrained models and be more diverse / less mode-collapsed than instruct models;
|
| 4 |
+
- be a really good, approximately bayesian in-context learner;
|
| 5 |
+
- fit a data generation process
|
| 6 |
+
- be calibrated over distributions of possible outputs wrt a population or epistemic uncertainty
|
| 7 |
+
It is initialized from `google/gemma-3-12b-pt`.
|
| 8 |
+
|
| 9 |
+
This model/repo is a work in progress - expect updates.
|
| 10 |
+
|
| 11 |
+
Loading model example:
|
| 12 |
+
```
|
| 13 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 14 |
+
tokenizer = AutoTokenizer.from_pretrained("tsor13/special12b", trust_remote_code=True) # custom tokenizer for handling messages / loss
|
| 15 |
+
model = AutoModelForCausalLM.from_pretrained("tsor13/special12b", device_map="auto")
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
It has its own chat-style input messages, with the following roles:
|
| 19 |
+
- `description`(optional): A description of the generating process, or some information meant to instantiate a prior
|
| 20 |
+
- `input` (optional): Any variables that a model is not responsible for predicting, but could be used to condition generation somehow;
|
| 21 |
+
- `output`: This is what the model will actually predict / generate.
|
| 22 |
+
|
| 23 |
+
For example,
|
| 24 |
+
```
|
| 25 |
+
messages = [
|
| 26 |
+
{"role": "description", "content": "Capitals"},
|
| 27 |
+
{"role": "input", "content": "France"},
|
| 28 |
+
{"role": "output", "content": "Paris"},
|
| 29 |
+
{"role": "input", "content": "Japan"},
|
| 30 |
+
]
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
To templatize the messages, you can use the tokenizer:
|
| 34 |
+
```
|
| 35 |
+
formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
|
| 36 |
+
print(formatted_prompt) # start_generation adds the <start_of_turn> token to condition the model for generation
|
| 37 |
+
```
|
| 38 |
+
Output:
|
| 39 |
+
```
|
| 40 |
+
Capitals
|
| 41 |
+
France
|
| 42 |
+
<start_of_turn>Paris<end_of_turn>
|
| 43 |
+
Japan
|
| 44 |
+
<start_of_turn>
|
| 45 |
+
```
|
| 46 |
+
The data for the model to emulate / generate is wrapped in `<start_of_turn>` / `<end_of_turn>` tokens.
|
| 47 |
+
Description and input are not wrapped in anything. Thus, do not expect the model to generate these tokens - instead focus on the wrapped output tokens.
|
| 48 |
+
Messages are separated by newlines.
|
| 49 |
+
|
| 50 |
+
In training, loss is ONLY calculated on the output tokens and the `<end_of_turn>` token. Thus, the model is only designed to generate / predict probabilities after `<start_of_turn>` and until `<end_of_turn>` - everything else is out of distribution for the model and not recommended.
|
| 51 |
+
|
| 52 |
+
Once you have the formatted text, you can tokenize as normal:
|
| 53 |
+
```
|
| 54 |
+
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
Let's look at what the model does. In this case, there is a single correct answer. Let's look at model probabilities after `<start_of_turn>`:
|
| 58 |
+
```
|
| 59 |
+
import torch
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
output = model(**inputs)
|
| 62 |
+
logits = output.logits[0, -1, :]
|
| 63 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
| 64 |
+
top_probs, top_indices = torch.topk(probs, 10)
|
| 65 |
+
print("\nTop 10 probabilities for first output token:")
|
| 66 |
+
for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
|
| 67 |
+
token = tokenizer.decode(idx)
|
| 68 |
+
print(f"{i+1:2d}. '{token}' -> {prob.item():.4f}")
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Output:
|
| 72 |
+
```
|
| 73 |
+
Top 10 probabilities for first output token:
|
| 74 |
+
1. 'Tokyo' -> 0.9764
|
| 75 |
+
2. 'Tok' -> 0.0070
|
| 76 |
+
3. '東京' -> 0.0026
|
| 77 |
+
4. 'Ky' -> 0.0019
|
| 78 |
+
5. 'T' -> 0.0014
|
| 79 |
+
6. ' Tokyo' -> 0.0014
|
| 80 |
+
7. 'To' -> 0.0011
|
| 81 |
+
8. 'Osaka' -> 0.0009
|
| 82 |
+
9. 'Toy' -> 0.0007
|
| 83 |
+
10. 'tok' -> 0.0005
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
Great! Almost all of the probability mass is on the correct answer, Tokyo.
|
| 87 |
+
|
| 88 |
+
Let's try an example with many possible reasonable choices / a harder to describe distribution. For example, say that I'm interested in modeling "board games that I like". I may be hard-pressed to actually describe what it is that I like about games - but I could provide a few examples pretty easily.
|
| 89 |
+
|
| 90 |
+
```
|
| 91 |
+
messages = [
|
| 92 |
+
{"role": "output", "content": "Dune: Imperium"},
|
| 93 |
+
{"role": "output", "content": "Acquire"},
|
| 94 |
+
{"role": "output", "content": "Catan"},
|
| 95 |
+
{"role": "output", "content": "Tigris and Euphrates"},
|
| 96 |
+
{"role": "output", "content": "Brass: Birmingham"},
|
| 97 |
+
]
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
Given these example outputs, the model will try to generate more outputs like these outputs.
|
| 101 |
+
```
|
| 102 |
+
formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
|
| 103 |
+
n_gens = 4
|
| 104 |
+
inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
|
| 105 |
+
|
| 106 |
+
outputs = model.generate(**inputs, max_new_tokens=10, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
|
| 107 |
+
for i in range(n_gens):
|
| 108 |
+
print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Outputs:
|
| 112 |
+
```
|
| 113 |
+
Catan: Rivals for Catan
|
| 114 |
+
Gloomhaven
|
| 115 |
+
Great Western Trail
|
| 116 |
+
Azul
|
| 117 |
+
```
|
| 118 |
+
Not too bad!
|
| 119 |
+
|
| 120 |
+
You can also specify just the description:
|
| 121 |
+
Input:
|
| 122 |
+
```
|
| 123 |
+
messages = [
|
| 124 |
+
{"role": "description", "content": "Descriptive colors"},
|
| 125 |
+
]
|
| 126 |
+
|
| 127 |
+
formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
|
| 128 |
+
n_gens = 4
|
| 129 |
+
inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
|
| 130 |
+
|
| 131 |
+
outputs = model.generate(**inputs, max_new_tokens=10, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
|
| 132 |
+
for i in range(n_gens):
|
| 133 |
+
print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
|
| 134 |
+
print()
|
| 135 |
+
```
|
| 136 |
+
Output:
|
| 137 |
+
```
|
| 138 |
+
Deep Sea Blue
|
| 139 |
+
Gray#222222
|
| 140 |
+
Gold, Red, Black
|
| 141 |
+
I can’t believe we’re already talking about color theory. How is this possible? Can time go any faster? Also how does your body
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
By default, the model is only trained to either 1) emulate outputs if examples are provided, or 2) generate data based on the description. Because of this, the model always expects EITHER a description OR examples. If you want it to act slightly more like an instruction-following chat model, you can add a description such as the following:
|
| 145 |
+
|
| 146 |
+
```
|
| 147 |
+
messages = [
|
| 148 |
+
{"role": "description", "content": "You are a helpful assistant who outputs the requested content."},
|
| 149 |
+
{"role": "input", "content": "A poem about a shark"},
|
| 150 |
+
]
|
| 151 |
+
```
|
| 152 |
+
To generate:
|
| 153 |
+
```
|
| 154 |
+
formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
|
| 155 |
+
n_gens = 4
|
| 156 |
+
inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
|
| 157 |
+
|
| 158 |
+
outputs = model.generate(**inputs, max_new_tokens=40, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
|
| 159 |
+
for i in range(n_gens):
|
| 160 |
+
print(f"Generation {i}:")
|
| 161 |
+
print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
Some example generations:
|
| 165 |
+
```
|
| 166 |
+
Generation 0:
|
| 167 |
+
A deep-sea creature, silent and fierce, Shivers through water, its body sleek. Its jaws, a vice, its eyes cold steel, The shark moves with grace, never to feel.
|
| 168 |
+
of power and danger,
|
| 169 |
+
Generation 1:
|
| 170 |
+
The great white shark lurks in the deep, with teeth so sharp, it could cut a whale in half. Its dorsal fin slices through the water, like a knife through butter, and its tail
|
| 171 |
+
Generation 2:
|
| 172 |
+
The shark swam in the sea, With a toothy grin, as if it could be glee. It was the top of the food chain, The apex of the sea's terrain. With sleek
|
| 173 |
+
Generation 3:
|
| 174 |
+
I am a gentle, tranquil wave, gliding smoothly across the ocean's expanse. Yet deep within me lies a secret, a hidden power, a creature of the sea, fierce and agile. It
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
Finally, let's look at a synthetic data generation task. For example, maybe we want to generate situations to do social reasoning over, along with whether or not they are awkward. When there are multiple variables to condition on or generate, the model uses JSON format.
|
| 181 |
+
|
| 182 |
+
Input:
|
| 183 |
+
```
|
| 184 |
+
import json
|
| 185 |
+
messages = [
|
| 186 |
+
{"role": "description", "content": "Situations to do social reasoning over, along with whether or not it is an awkward situation."},
|
| 187 |
+
{"role": "output", "content": json.dumps({
|
| 188 |
+
"situation": "You're at a party and you realize that your shirt is on backwards.",
|
| 189 |
+
"is_awkward": True,
|
| 190 |
+
})},
|
| 191 |
+
{"role": "output", "content": json.dumps({
|
| 192 |
+
"situation": "While at work, your boss commends you on a job well done.",
|
| 193 |
+
"is_awkward": False,
|
| 194 |
+
})},
|
| 195 |
+
{"role": "output", "content": json.dumps({
|
| 196 |
+
"situation": "Realizing you forgot to bring your passport to the airport.",
|
| 197 |
+
"is_awkward": True,
|
| 198 |
+
})},
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
+
formatted_prompt = tokenizer.messages_to_text(messages, start_generation=True)
|
| 202 |
+
n_gens = 4
|
| 203 |
+
inputs = tokenizer([formatted_prompt] * n_gens, return_tensors="pt").to(model.device)
|
| 204 |
+
|
| 205 |
+
outputs = model.generate(**inputs, max_new_tokens=40, stop_strings=["<end_of_turn>"], tokenizer=tokenizer)
|
| 206 |
+
for i in range(n_gens):
|
| 207 |
+
print(tokenizer.decode(outputs[i][inputs["input_ids"][i].shape[0]:], skip_special_tokens=True))
|
| 208 |
+
```
|
| 209 |
+
Output:
|
| 210 |
+
```
|
| 211 |
+
{"situation": "While walking on the street, someone waves and smiles at you, but you don't know them.", "is_awkward": false}
|
| 212 |
+
{"situation": "Taking a cab and giving the driver wrong directions.", "is_awkward": true}
|
| 213 |
+
{"situation": "Being told that an individual you've had a long-term crush on is also crushing on someone else.", "is_awkward": true}
|
| 214 |
+
{"situation": "Watching a loved one get proposed to.", "is_awkward": false}
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
A few tips and tricks:
|
| 218 |
+
- Do not expect the model to do multi-turn chats. It is designed to be stateless and to treat each data point as "exchangeable" (roughly iid).
|
| 219 |
+
- If all you want is one reasonable answer, then a chat model is likely a better fit. However, if you want to generate many reasonable answers / diverse examples, this model is a better fit.
|
| 220 |
+
- The model is quite good at perspective taking / steering if you provide many examples.
|
| 221 |
+
- The model is reasonably good at expressing epistemic uncertainty over unsure outputs by sampling several times.
|
added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<image_soft_token>": 262144
|
| 3 |
+
}
|
config.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Gemma3ForConditionalGeneration"
|
| 4 |
+
],
|
| 5 |
+
"boi_token_index": 255999,
|
| 6 |
+
"eoi_token_index": 256000,
|
| 7 |
+
"eos_token_id": [
|
| 8 |
+
1,
|
| 9 |
+
106
|
| 10 |
+
],
|
| 11 |
+
"image_token_index": 262144,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"mm_tokens_per_image": 256,
|
| 14 |
+
"model_type": "gemma3",
|
| 15 |
+
"text_config": {
|
| 16 |
+
"attention_bias": false,
|
| 17 |
+
"attention_dropout": 0.0,
|
| 18 |
+
"attn_logit_softcapping": null,
|
| 19 |
+
"cache_implementation": "hybrid",
|
| 20 |
+
"final_logit_softcapping": null,
|
| 21 |
+
"head_dim": 256,
|
| 22 |
+
"hidden_activation": "gelu_pytorch_tanh",
|
| 23 |
+
"hidden_size": 3840,
|
| 24 |
+
"initializer_range": 0.02,
|
| 25 |
+
"intermediate_size": 15360,
|
| 26 |
+
"max_position_embeddings": 131072,
|
| 27 |
+
"model_type": "gemma3_text",
|
| 28 |
+
"num_attention_heads": 16,
|
| 29 |
+
"num_hidden_layers": 48,
|
| 30 |
+
"num_key_value_heads": 8,
|
| 31 |
+
"query_pre_attn_scalar": 256,
|
| 32 |
+
"rms_norm_eps": 1e-06,
|
| 33 |
+
"rope_local_base_freq": 10000.0,
|
| 34 |
+
"rope_scaling": {
|
| 35 |
+
"factor": 8.0,
|
| 36 |
+
"rope_type": "linear"
|
| 37 |
+
},
|
| 38 |
+
"rope_theta": 1000000.0,
|
| 39 |
+
"sliding_window": 1024,
|
| 40 |
+
"sliding_window_pattern": 6,
|
| 41 |
+
"torch_dtype": "float32",
|
| 42 |
+
"use_cache": true,
|
| 43 |
+
"vocab_size": 262208
|
| 44 |
+
},
|
| 45 |
+
"torch_dtype": "bfloat16",
|
| 46 |
+
"transformers_version": "4.51.3",
|
| 47 |
+
"vision_config": {
|
| 48 |
+
"attention_dropout": 0.0,
|
| 49 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 50 |
+
"hidden_size": 1152,
|
| 51 |
+
"image_size": 896,
|
| 52 |
+
"intermediate_size": 4304,
|
| 53 |
+
"layer_norm_eps": 1e-06,
|
| 54 |
+
"model_type": "siglip_vision_model",
|
| 55 |
+
"num_attention_heads": 16,
|
| 56 |
+
"num_channels": 3,
|
| 57 |
+
"num_hidden_layers": 27,
|
| 58 |
+
"patch_size": 14,
|
| 59 |
+
"torch_dtype": "float32",
|
| 60 |
+
"vision_use_head": false
|
| 61 |
+
}
|
| 62 |
+
}
|
gemma_chat_tokenizer.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom Gemma Tokenizer for chat Format
|
| 3 |
+
|
| 4 |
+
This tokenizer implements the chat format for message processing:
|
| 5 |
+
Format: Uses the standard chat template with proper role labels (user/assistant)
|
| 6 |
+
|
| 7 |
+
The chat format uses the model's built-in chat template and includes proper
|
| 8 |
+
loss computation flags for training with "assistant" as the generation role.
|
| 9 |
+
|
| 10 |
+
To save:
|
| 11 |
+
uv run tokenizers/gemma_chat_tokenizer.py
|
| 12 |
+
which will save the tokenizer to the repos/chat-gemma-tokenizer directory.
|
| 13 |
+
mkdir repos/chat12b
|
| 14 |
+
# copy model over
|
| 15 |
+
cp models_v8/base_modified-google-gemma-3-12b-pt-/models/_chat/checkpoint-8/* repos/chat12b/
|
| 16 |
+
# copy tokenizer over
|
| 17 |
+
cp repos/chat-gemma-tokenizer/* repos/chat12b/
|
| 18 |
+
# upload to hf
|
| 19 |
+
|
| 20 |
+
uv run upload_to_hf.py \
|
| 21 |
+
--folder repos/chat12b \
|
| 22 |
+
--repo-id tsor13/chat12b
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from typing import List, Dict, Any, Optional, Union
|
| 26 |
+
from transformers import AutoTokenizer
|
| 27 |
+
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
|
| 28 |
+
from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
|
| 29 |
+
import warnings
|
| 30 |
+
import difflib
|
| 31 |
+
import json
|
| 32 |
+
import os
|
| 33 |
+
import sys
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class GemmaChatTokenizer(GemmaTokenizerFast):
|
| 37 |
+
"""
|
| 38 |
+
Custom tokenizer for Gemma models that implements chat format message processing.
|
| 39 |
+
|
| 40 |
+
This tokenizer formats messages using the chat format where:
|
| 41 |
+
- Messages use the standard chat template with proper role labels
|
| 42 |
+
- Uses the model's built-in chat formatting
|
| 43 |
+
- Loss is computed on the assistant sections (not output)
|
| 44 |
+
|
| 45 |
+
Attributes:
|
| 46 |
+
start_string (str): The starting string used for output generation (depends on tokenizer)
|
| 47 |
+
end_string (str): The ending string used for output generation (depends on tokenizer)
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, *args, **kwargs):
|
| 51 |
+
"""
|
| 52 |
+
Initialize the custom tokenizer.
|
| 53 |
+
|
| 54 |
+
Accepts the same arguments as GemmaTokenizerFast.
|
| 55 |
+
"""
|
| 56 |
+
super().__init__(*args, **kwargs)
|
| 57 |
+
|
| 58 |
+
# For chat format, we use the tokenizer's own chat template
|
| 59 |
+
# The start/end strings will be determined by the chat template
|
| 60 |
+
self.start_string = "<start_of_turn>" # Will be set dynamically
|
| 61 |
+
self.end_string = "<end_of_turn>" # Will be set dynamically
|
| 62 |
+
|
| 63 |
+
# Add custom attributes to the tokenizer config for saving/loading
|
| 64 |
+
if not hasattr(self, 'init_kwargs'):
|
| 65 |
+
self.init_kwargs = {}
|
| 66 |
+
self.init_kwargs['start_string'] = self.start_string
|
| 67 |
+
self.init_kwargs['end_string'] = self.end_string
|
| 68 |
+
|
| 69 |
+
@classmethod
def from_gemma_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """
    Load a tokenizer from a pretrained model or path.

    This method ensures our custom class is used instead of the base
    GemmaTokenizerFast.
    """
    # Load a plain GemmaTokenizerFast first so we pick up the full
    # pretrained configuration.
    base = GemmaTokenizerFast.from_pretrained(
        pretrained_model_name_or_path, *args, **kwargs
    )

    # Build an instance of this class without running __init__, then adopt
    # every attribute of the base tokenizer via setattr (so any property
    # setters still fire).
    tokenizer = cls.__new__(cls)
    for name, value in base.__dict__.items():
        setattr(tokenizer, name, value)

    # Chat-format turn markers (same values __init__ would set).
    tokenizer.start_string = "<start_of_turn>"
    tokenizer.end_string = "<end_of_turn>"

    # Record the markers so they persist through save_pretrained.
    if not hasattr(tokenizer, "init_kwargs"):
        tokenizer.init_kwargs = {}
    tokenizer.init_kwargs["start_string"] = tokenizer.start_string
    tokenizer.init_kwargs["end_string"] = tokenizer.end_string

    return tokenizer
|
| 99 |
+
|
| 100 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
    """
    Save the tokenizer to a directory, including custom configuration.
    """
    # Let the parent write all standard tokenizer files first.
    super().save_pretrained(save_directory, **kwargs)

    # Then layer our custom fields on top of tokenizer_config.json.
    config_path = os.path.join(save_directory, "tokenizer_config.json")
    try:
        with open(config_path, "r") as fh:
            config = json.load(fh)
    except FileNotFoundError:
        config = {}

    config.update(
        {
            "tokenizer_class": "GemmaChatTokenizer",
            "start_string": self.start_string,
            "end_string": self.end_string,
            # Point AutoTokenizer at the custom class file uploaded with
            # the model (slow and fast entries both map to this class).
            "auto_map": {
                "AutoTokenizer": [
                    "gemma_chat_tokenizer.GemmaChatTokenizer",
                    "gemma_chat_tokenizer.GemmaChatTokenizer",
                ]
            },
        }
    )

    with open(config_path, "w") as fh:
        json.dump(config, fh, indent=2)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def messages_to_loss_texts(
    self,
    messages: List[Dict[str, Any]],
    loss_on_start_token: bool = False,
    default_user_message: str = "Generate.",
    start_generation: bool = False,
) -> List[Dict[str, Any]]:
    """
    From messages (description / input / output) to texts (text / compute_loss) with whether or not loss should be calculated on the text for training.
    Uses the chat format matching chat_utils.py.

    Each returned item is a dict with:
      - "text": a substring of the templated prompt, and
      - "compute_loss": whether training loss should be computed on it.
    Concatenating the "text" fields in order reconstructs the full
    templated string.
    """
    # FOR NOW, OVERRIDING TO FALSE
    # (the loss_on_start_token parameter is accepted but deliberately ignored)
    loss_on_start_token = False

    texts = []
    chat_messages = []
    has_input = False
    has_description = False

    # Convert to chat format: description -> system, input -> user,
    # output -> assistant (with a synthetic user turn when no input
    # preceded the output, since chat templates expect alternation).
    for message in messages:
        if message["role"] == "description":
            chat_messages.append({"role": "system", "content": "Generate something that fits this description. Don't generate anything else, just the desired generation output.\nDescription: " + message["content"]})
            has_description = True
        elif message["role"] == "input":
            has_input = True
            chat_messages.append({"role": "user", "content": message["content"]})
        elif message["role"] == "output":
            if not has_input:
                chat_messages.append({"role": "user", "content": default_user_message})
            chat_messages.append({"role": "assistant", "content": message["content"]})
    # if last message is output and start_generation is true, add a default user message
    # NOTE(review): chat_messages[-1] raises IndexError when `messages` is
    # empty — confirm callers always pass at least one message.
    if start_generation and chat_messages[-1]["role"] == "assistant":
        chat_messages.append({"role": "user", "content": default_user_message})

    # Apply chat template
    full_text = self.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=start_generation)
    # replace <bos> with nothing
    full_text = full_text.replace("<bos>", "")

    text_to_split = full_text
    # now, find all places starting with <start_of_turn>model\n
    # (the model-turn header emitted by Gemma's chat template)
    model_start_text = "<start_of_turn>model\n" # TODO - manual for now, change later
    first = True
    # Walk the templated string, splitting at each model-turn header.
    # The header itself never gets loss; the span up to and including
    # end_string is the model's output and (usually) gets loss.
    while model_start_text in text_to_split:
        # get location of model_start_text
        model_start_loc = text_to_split.find(model_start_text)
        split_ind = model_start_loc + len(model_start_text)
        text_to_add, text_to_split = text_to_split[:split_ind], text_to_split[split_ind:]
        # add to texts
        texts.append({"text": text_to_add, "compute_loss": False})
        # get location of end_string
        # NOTE(review): if end_string is absent, find() returns -1 and the
        # slice below is truncated oddly — presumably the template always
        # closes each model turn with end_string; verify.
        end_string_loc = text_to_split.find(self.end_string)
        end_ind = end_string_loc + len(self.end_string)
        text_to_add, text_to_split = text_to_split[:end_ind], text_to_split[end_ind:]
        # The very first model turn is skipped for loss when there is no
        # description (it may not be representative of the target
        # distribution); all later turns always get loss.
        if first and not has_description:
            texts.append({"text": text_to_add, "compute_loss": False})
        else:
            texts.append({"text": text_to_add, "compute_loss": True})
        first = False
    # Any trailing text after the last model turn carries no loss.
    if len(text_to_split) > 0:
        texts.append({"text": text_to_split, "compute_loss": False})

    return texts
|
| 192 |
+
|
| 193 |
+
def messages_to_text(
    self,
    messages: List[Dict[str, Any]],
    start_generation: bool = False,
) -> str:
    """
    Messages (description / input / output) to raw text (text).
    Uses the chat format matching chat_utils.py.

    Delegates to messages_to_loss_texts and concatenates the text
    segments, discarding the per-segment compute_loss flags.
    """
    segments = self.messages_to_loss_texts(messages, start_generation=start_generation)
    return "".join(segment["text"] for segment in segments)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def tokenize_messages(
    self,
    messages: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
    start_generation: bool = False,
    **kwargs,
):
    """
    For tokenizing from messages to texts. Supports batching. Good for generation.

    Accepts either a single conversation (list of message dicts) or a
    batch (list of conversations); extra kwargs are forwarded to the
    tokenizer call.
    """
    # Normalize to a batch: a list whose first element is itself a list
    # is already batched; otherwise wrap the single conversation.
    is_batch = isinstance(messages, list) and isinstance(messages[0], list)
    conversations = messages if is_batch else [messages]

    # Render each conversation through the chat template, then tokenize
    # all rendered strings together.
    rendered = [self.messages_to_text(conv, start_generation) for conv in conversations]
    return self(text=rendered, **kwargs)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def tokenize_loss_texts(
    self,
    texts: List[Dict[str, Any]],
    loss_on_start_token: bool = False,
    loss_on_eos: bool = False,
    include_eos: bool = True,
):
    """
    Tokenize texts (text / compute_loss) to tokenized texts (input_ids /
    attention_mask / labels).

    Each element of ``texts`` is a dict with:
      - "text": the raw string for this span
      - "compute_loss": True -> labels mirror input_ids; False -> labels
        are -100 (masked out of the loss)
      - optional "example_ind" / "data_id": propagated per-token into the
        output keys "example_inds" / "data_ids" (-1 where absent)

    A plain string input is tokenized directly with no labels (only a
    trailing EOS is ensured).

    NOTE(review): ``loss_on_start_token`` is accepted for interface
    compatibility but is never read in this method.

    Raises:
        ValueError: if ``loss_on_eos`` is True (not supported), or if an
            EOS token appears mid-span and is not "<|im_end|>".
    """
    if loss_on_eos:
        raise ValueError("Loss on EOS is not currently supported.")

    # Handle single string input
    if isinstance(texts, str):
        processed = self(text=texts)
        # Add EOS token if needed
        if (self.eos_token_id is not None and
                processed["input_ids"][-1] != self.eos_token_id):
            processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
            processed["attention_mask"] = processed["attention_mask"] + [1]
        return processed

    # Handle list of text dictionaries
    all_processed = []
    all_texts = ''
    example_inds = []
    dataset_inds = []

    for i, item in enumerate(texts):
        processed = self(text=item["text"])

        # Remove BOS token from all but the first item so concatenation
        # yields a single-BOS sequence
        if i != 0 and self.bos_token_id == processed["input_ids"][0]:
            processed["input_ids"] = processed["input_ids"][1:]
            processed["attention_mask"] = processed["attention_mask"][1:]

        # Remove EOS token if present at the end (a single EOS is
        # re-appended to the combined sequence below)
        if processed["input_ids"][-1] == self.eos_token_id:
            processed["input_ids"] = processed["input_ids"][:-1]
            processed["attention_mask"] = processed["attention_mask"][:-1]

        # Check for EOS token in the middle (with special handling for <|im_end|>)
        if self.eos_token_id in processed["input_ids"]:
            if not self.decode([self.eos_token_id]) == "<|im_end|>":
                raise ValueError(f"EOS token is present in input_ids: {processed['input_ids']}. Not currently supported.")

        # Set labels based on compute_loss flag (-100 is ignored by the loss)
        if item["compute_loss"]:
            processed["labels"] = processed["input_ids"].copy()
        else:
            processed["labels"] = [-100] * len(processed["input_ids"])

        # Remove duplicate BOS tokens (redundant with the i != 0 strip
        # above for most tokenizers, kept as a safety net — this variant
        # also trims the matching label)
        if all_processed:
            if processed["input_ids"][0] == self.bos_token_id:
                processed["input_ids"] = processed["input_ids"][1:]
                processed["attention_mask"] = processed["attention_mask"][1:]
                processed["labels"] = processed["labels"][1:]

        all_processed.append(processed)
        all_texts += item["text"]

        # Handle example indices: one entry per token, -1 when the item
        # carries no example_ind
        this_num = -1
        if 'example_ind' in item.keys():
            if item["example_ind"] is not None:
                this_num = item["example_ind"]
        example_inds.extend([this_num] * len(processed["input_ids"]))

        # Handle dataset indices: same per-token scheme as example_inds
        dataset_ind = -1
        if "data_id" in item.keys():
            if item["data_id"] is not None:
                dataset_ind = item["data_id"]
        dataset_inds.extend([dataset_ind] * len(processed["input_ids"]))

    # Combine all processed results into a single flat encoding
    processed = all_processed[0].copy()
    processed["input_ids"] = [item for sublist in [p["input_ids"] for p in all_processed] for item in sublist]
    processed["attention_mask"] = [item for sublist in [p["attention_mask"] for p in all_processed] for item in sublist]
    processed["labels"] = [item for sublist in [p["labels"] for p in all_processed] for item in sublist]
    processed["example_inds"] = example_inds
    processed["data_ids"] = dataset_inds

    # Validate by tokenizing all_texts at once and comparing lengths.
    # Fix: the debug diffs below were previously computed and printed on
    # EVERY call; they now run only when a mismatch is actually detected.
    processed_all = self(text=all_texts)
    if len(processed_all["input_ids"]) != len(processed["input_ids"]):
        warnings.warn(f"All texts are not the same length as the first text. Please check your dataset. {len(processed_all['input_ids'])} != {len(processed['input_ids'])}")

        # Generate text-level diff for debugging
        all_text = self.decode(processed_all["input_ids"], skip_special_tokens=False)
        processed_text = self.decode(processed["input_ids"], skip_special_tokens=False)

        diff = difflib.unified_diff(all_text.splitlines(), processed_text.splitlines())
        diff_str = "\n".join(diff)
        print("Diff between texts:")
        print(diff_str)

        # Token-level diff (one token id per line)
        all_tokens_str = '\n'.join([str(s) for s in processed_all["input_ids"]])
        processed_tokens_str = '\n'.join([str(s) for s in processed["input_ids"]])
        token_diff = difflib.unified_diff(all_tokens_str.splitlines(), processed_tokens_str.splitlines())
        token_diff_str = "\n".join(token_diff)
        print("Diff between tokenized texts:")
        print(token_diff_str)

    # Add EOS token if needed, extending every parallel per-token list
    if (self.eos_token_id is not None and
            processed["input_ids"][-1] != self.eos_token_id):
        processed["input_ids"] = processed["input_ids"] + [self.eos_token_id]
        processed["example_inds"] = processed["example_inds"] + [-1]
        processed["attention_mask"] = processed["attention_mask"] + [1]
        if processed["labels"] is not None:
            if loss_on_eos:
                # Unreachable today: loss_on_eos raises at the top of this
                # method. Kept for when that restriction is lifted.
                processed["labels"] = processed["labels"] + [self.eos_token_id]
            else:
                processed["labels"] = processed["labels"] + [-100]
        if "data_ids" in processed:
            processed["data_ids"] = processed["data_ids"] + [-1]

    if not include_eos:
        # check if EOS token is present
        if processed["input_ids"][-1] == self.eos_token_id:
            # remove EOS token and the matching tail of each parallel list
            processed["input_ids"] = processed["input_ids"][:-1]
            processed["attention_mask"] = processed["attention_mask"][:-1]
            processed["labels"] = processed["labels"][:-1]
            processed["example_inds"] = processed["example_inds"][:-1]
            processed["data_ids"] = processed["data_ids"][:-1]

    return processed
|
| 366 |
+
|
| 367 |
+
def tokenize_messages(
    self,
    messages: List[Dict[str, Any]],
    loss_on_start_token: bool = False,
    loss_on_eos: bool = False,
    include_eos: bool = True,
) -> Dict[str, Any]:
    """
    Tokenize messages into input_ids / attention_mask / labels with the
    loss mask applied (training path).

    NOTE(review): this definition shadows the earlier, generation-oriented
    ``tokenize_messages`` defined above in this file — only this one is
    reachable as a method.

    Raises:
        ValueError: propagated from ``tokenize_loss_texts`` when
            ``loss_on_eos`` is True (not currently supported).
    """
    # First convert messages to text spans with per-span loss flags
    texts = self.messages_to_loss_texts(messages, loss_on_start_token)

    # Then tokenize the spans. Bug fix: pass the flags by keyword —
    # previously `loss_on_eos` was passed positionally and bound to the
    # `loss_on_start_token` parameter, silently dropping the caller's
    # loss_on_eos value.
    return self.tokenize_loss_texts(
        texts,
        loss_on_start_token=loss_on_start_token,
        loss_on_eos=loss_on_eos,
        include_eos=include_eos,
    )
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# Register tokenizer classes for AutoTokenizer so that
# AutoTokenizer.from_pretrained can resolve the "GemmaChatTokenizer"
# config name to this class (no slow-tokenizer variant is provided).
AutoTokenizer.register("GemmaChatTokenizer", slow_tokenizer_class=None, fast_tokenizer_class=GemmaChatTokenizer)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
if __name__ == "__main__":
    # Example usage / smoke test: build the custom tokenizer from a stock
    # Gemma checkpoint, exercise the message -> text conversion on a few
    # message lists, then save the tokenizer (plus this source file) to a
    # local directory. Requires network access for the initial download.

    # for first load
    custom_tokenizer = GemmaChatTokenizer.from_gemma_pretrained("google/gemma-3-1b-it")

    # for subsequent loads
    # custom_tokenizer = GemmaChatTokenizer.from_pretrained("tsor13/chat-gemma-12b-pt")
    # custom_tokenizer = GemmaChatTokenizer.from_pretrained("repos/chat-gemma-12b-pt")

    # Test messages in role/content format. The four cases cover:
    # description + alternating input/output, description + outputs only,
    # outputs only (no description), and input/output pairs ending on an
    # input (generation-style prompt).
    test_messages = [
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
        ],
        [
            {"role": "description", "content": "This is a test task"},
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "output", "content": "4"},
            {"role": "output", "content": "10"},
            {"role": "output", "content": "13"},
        ],
        [
            {"role": "input", "content": "What is 2+2?"},
            {"role": "output", "content": "4"},
            {"role": "input", "content": "What is 3+3?"},
            {"role": "output", "content": "10"},
            {"role": "input", "content": "What is 4+4?"},
        ],
    ]
    for messages in test_messages:
        # get messages to text_loss (per-span compute_loss flags)
        texts = custom_tokenizer.messages_to_loss_texts(messages)

        print("Texts with loss flags:")
        for i, text in enumerate(texts):
            print(f" {i}: {text}")

        # full rendered string, with the generation prompt appended
        text = custom_tokenizer.messages_to_text(messages, start_generation=True)
        print(f"\nFull text with generation prompt:")
        print(text)

    print("\nTesting save/load cycle:")
    # Test saving and loading
    tokenizer_path = "repos/chat-gemma-tokenizer"
    custom_tokenizer.save_pretrained(tokenizer_path)
    print("Tokenizer saved successfully!")

    # also save this file in the tokenizer_path, so the saved tokenizer
    # directory is self-contained (trust_remote_code-style loading)
    import shutil
    shutil.copy(__file__, os.path.join(tokenizer_path, "gemma_chat_tokenizer.py"))
    print("GemmaChatTokenizer.py saved successfully!")
|
generation_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 2,
|
| 3 |
+
"cache_implementation": "hybrid",
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": [
|
| 6 |
+
1,
|
| 7 |
+
106
|
| 8 |
+
],
|
| 9 |
+
"pad_token_id": 0,
|
| 10 |
+
"top_k": 64,
|
| 11 |
+
"top_p": 0.95,
|
| 12 |
+
"transformers_version": "4.51.3"
|
| 13 |
+
}
|
model-00001-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b6cef91d3c32c85e95d0b452e190b7701b85ee4ff313d483349b24290f62afee
|
| 3 |
+
size 4979902192
|
model-00002-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0accacd83d92614807758dfc0661690fc9f3d9f6556ae8b46e61de5a70d3bdbf
|
| 3 |
+
size 4931296592
|
model-00003-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4147b87360deb84e430607bf3686e5ec29d2ab175978d8be42294151f7b6e26d
|
| 3 |
+
size 4931296656
|
model-00004-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:14c35cdcf195c4ac45cdaacb811f0f8c7e4130de125de48aea9c700bb9789a45
|
| 3 |
+
size 4931296656
|
model-00005-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e287f5a199f0d42dac15a5434426f96ea2e2c44a6ac1b6981a1c73338f67aa6
|
| 3 |
+
size 4601000928
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"boi_token": "<start_of_image>",
|
| 3 |
+
"bos_token": {
|
| 4 |
+
"content": "<bos>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false
|
| 9 |
+
},
|
| 10 |
+
"eoi_token": "<end_of_image>",
|
| 11 |
+
"eos_token": {
|
| 12 |
+
"content": "<eos>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false
|
| 17 |
+
},
|
| 18 |
+
"image_token": "<image_soft_token>",
|
| 19 |
+
"pad_token": {
|
| 20 |
+
"content": "<pad>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false
|
| 25 |
+
},
|
| 26 |
+
"unk_token": {
|
| 27 |
+
"content": "<unk>",
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"normalized": false,
|
| 30 |
+
"rstrip": false,
|
| 31 |
+
"single_word": false
|
| 32 |
+
}
|
| 33 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
|
| 3 |
+
size 33384568
|
tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24572f35d542bcfa4879d51a8c08dbd9d362bc088e00db35b31d5234790d8fe5
|
| 3 |
+
size 7377
|