Initial upload of fine-tuned Gemma + custom tokenizer

Files changed:
- README.md: +46 −31
- gemma_chat_tokenizer.py: +14 −0

README.md
CHANGED
@@ -37,15 +37,19 @@ print(formatted_prompt) # start_generation adds the <start_of_turn> token to con
 ```
 Output:
 ```
-…
-…
-…
-…
-<…
+<start_of_turn>user
+Generate something that fits this description. Don't generate anything else, just the desired generation output.
+Description: Capitals
+
+France<end_of_turn>
+<start_of_turn>model
+Paris<end_of_turn>
+<start_of_turn>user
+Japan<end_of_turn>
+<start_of_turn>model
+
 ```
 The data for the model to emulate / generate is wrapped in `<start_of_turn>` / `<end_of_turn>` tokens.
-Description and input is not wrapped in anything. Thus, do not expect the model to generate these tokens - instead focus on the wrapped output tokens.
-Messages are separated by newlines.

 In training, loss is ONLY calculated on the output tokens and the `<end_of_turn>` token. Thus, the model is only designed to generate / predict probabilities after `<start_of_turn>` and until `<end_of_turn>` - everything else is out of distribution for the model and not recommended.
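Since this masking rule is central to how the model behaves, here is a minimal sketch of how output-only loss masking of this kind is typically implemented. The function, its arguments, and the single-span simplification are illustrative assumptions, not this repo's actual training code.

```python
# Minimal sketch of output-only loss masking (illustrative; not this repo's
# actual training code). Tokens outside <start_of_turn>...<end_of_turn> spans
# get the label -100, which torch's CrossEntropyLoss ignores.
import torch

def mask_labels(input_ids: torch.Tensor, start_id: int, end_id: int) -> torch.Tensor:
    labels = torch.full_like(input_ids, -100)
    in_output = False
    for i, tok in enumerate(input_ids.tolist()):
        if in_output:
            labels[i] = tok    # supervise span content, including <end_of_turn>
        if tok == start_id:
            in_output = True   # start supervising after <start_of_turn>
        elif tok == end_id:
            in_output = False  # <end_of_turn> itself was supervised above
    return labels
```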
@@ -70,17 +74,17 @@ with torch.no_grad():

 Output:
 ```
-Top 10 probabilities for first output token:
-1. 'Tokyo' -> 0.…
-2. 'Tok' -> 0.…
-3. '…
-4. '…
-5. '…
-6. '…
-7. '…
-8. '…
-9. '…
-10. '…
+Top 10 probabilities for first output token:
+1. 'Tokyo' -> 0.9330
+2. 'Tok' -> 0.0114
+3. 'Ky' -> 0.0064
+4. 'Washington' -> 0.0025
+5. 'To' -> 0.0019
+6. 'Japan' -> 0.0016
+7. 'tok' -> 0.0014
+8. 'N' -> 0.0013
+9. 'K' -> 0.0012
+10. 'Toy' -> 0.0011
 ```

 Great! Almost all of the probability mass is on the correct answer, Tokyo.
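For reference, a readout in this shape can be produced with a few lines of torch. Here `model`, `tokenizer`, and `inputs` are assumed to be the fine-tuned model, this repo's tokenizer, and the encoded prompt from the snippet above.

```python
import torch

with torch.no_grad():
    logits = model(**inputs).logits           # (batch, seq_len, vocab_size)

probs = torch.softmax(logits[0, -1], dim=-1)  # distribution over the next token
top = torch.topk(probs, k=10)

print("Top 10 probabilities for first output token:")
for rank, (p, idx) in enumerate(zip(top.values, top.indices), start=1):
    print(f"{rank}. '{tokenizer.decode(idx)}' -> {p.item():.4f}")
```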
@@ -110,10 +114,10 @@ for i in range(n_gens):

 Outputs:
 ```
-…
-…
-…
-…
+Terraforming Mars
+Scythe
+Concordia
+7 Wonders
 ```
 Not too bad!
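A sampling loop in this shape can produce such varied outputs. The decoding parameters below are assumptions for illustration, not the repo's exact settings.

```python
# Sketch of a sampling loop consistent with the outputs above.
end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")

for i in range(n_gens):
    out = model.generate(
        **inputs,
        do_sample=True,               # sample so repeated calls differ
        max_new_tokens=32,
        eos_token_id=end_of_turn_id,  # stop once the output turn closes
    )
    new_tokens = out[0, inputs["input_ids"].shape[1]:]
    print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```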
@@ -164,14 +168,25 @@ for i in range(n_gens):
 Some example generations:
 ```
 Generation 0:
-…
-…of power and danger,
+No content
 Generation 1:
-…
+An underwater menace,
+With a wide, dark mouth.
+Silent in the deep sea,
+A toothy and a fearsome south.
 Generation 2:
-…
+Shivers of ocean, a silent dread,
+Shadowed fin above your head.
+Eyes of black, a piercing stare,
+Hunting through the depths with care.
+Jaws of power,
 Generation 3:
-…
+Gleaming through the ocean blue,
+A silent hunter, strong and true.
+Sharp teeth and eyes of ancient might,
+A shadow moving in the light.
+
+With graceful fins it gl
 ```
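Generations 2 and 3 cut off mid-phrase, the pattern you see when the token budget runs out before `<end_of_turn>` is sampled. A quick check, reusing `out` and `end_of_turn_id` from the loop above (the cause is an inference here, not something stated in the README):

```python
# Heuristic: did the generation end with <end_of_turn>, or was it truncated?
if out[0, -1].item() != end_of_turn_id:
    print("Truncated by max_new_tokens; raise the budget for complete poems.")
```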
@@ -208,10 +223,10 @@ for i in range(n_gens):
 ```
 Output:
 ```
-{"situation": "…
-{"situation": "…
-{"situation": "…
-{"situation": "…
+{"situation": "You're in the cafeteria at school and your professor is behind you in the line.", "is_awkward": false}
+{"situation": "During your walk home, you notice someone has lost their wallet and pick it up.", "is_awkward": false}
+{"situation": "You're at the bar and someone approaches you.", "is_awkward": false}
+{"situation": "Your friend reveals a secret you already knew but they didn't realize you did.", "is_awkward": false}
 ```

 A few tips and tricks:
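On the structured output shown in this hunk: since the generations are JSON lines, a light validation pass helps before using them downstream. `parse_situation` is a hypothetical helper, not part of this repo.

```python
import json

def parse_situation(line: str):
    """Return the parsed dict, or None if the generation isn't valid JSON."""
    try:
        obj = json.loads(line)
    except json.JSONDecodeError:
        return None  # drop malformed generations
    if isinstance(obj, dict) and {"situation", "is_awkward"} <= obj.keys():
        return obj
    return None
```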
gemma_chat_tokenizer.py
CHANGED
@@ -160,6 +160,15 @@ class GemmaChatTokenizer(GemmaTokenizerFast):
         if start_generation and chat_messages[-1]["role"] == "assistant":
             chat_messages.append({"role": "user", "content": default_user_message})

+        # if len(chat_messages) == 1:
+        #     # change to user
+        #     chat_messages[0]["role"] = "user"
+        #     # add
+        # # TAYLOR - manual for now because of the way gemma handles only having a system prompt
+        if not has_input and len(chat_messages) == 1:
+            # add a default user message
+            chat_messages.append({"role": "user", "content": default_user_message})
+
         # Apply chat template
         full_text = self.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=start_generation)
         # replace <bos> with nothing
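The effect of the new branch: a description-only conversation picks up a default user turn before Gemma's chat template runs, since (per the inline comment) the template does not cope with a system-style message on its own. A sketch of the before/after, with `default_user_message` as defined elsewhere in the file:

```python
# Before the branch (has_input is False, a single description-style message):
messages = [{"role": "description", "content": "Pick a number between 1 and 100"}]

# After the branch, the list the chat template actually sees:
# [
#     {"role": "description", "content": "Pick a number between 1 and 100"},
#     {"role": "user", "content": default_user_message},
# ]
```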
|
@@ -259,6 +268,7 @@ class GemmaChatTokenizer(GemmaTokenizerFast):
         all_texts = ''
         example_inds = []
         dataset_inds = []
+

         for i, item in enumerate(texts):
             processed = self(text=item["text"])
|
@@ -398,6 +408,10 @@ if __name__ == "__main__":

     # Test messages in role/content format
     test_messages = [
+        [
+            {"role": "description", "content": "Pick a number between 1 and 100"},
+        ],
+
         [
             {"role": "description", "content": "This is a test task"},
             {"role": "input", "content": "What is 2+2?"},
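The new first entry in `test_messages` exercises the description-only branch added above. A usage sketch; `format_chat` is a placeholder name, since the class's real entry point that wraps this logic (the one that produced `formatted_prompt` in the README, with its `start_generation` flag) isn't visible in these hunks:

```python
# Sketch: run the new description-only test case through the tokenizer.
# `format_chat` is a placeholder; substitute the class's actual method.
tok = GemmaChatTokenizer.from_pretrained(".")
for messages in test_messages:
    print(tok.format_chat(messages, start_generation=True))
```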