DavidSeyserHF committed on
Commit
3abc4f7
·
verified ·
1 Parent(s): a16eb9a

Update rex1-base: mixed-2 checkpoint step 710000

Browse files
Files changed (6) hide show
  1. README.md +3 -3
  2. export_metadata.json +2 -2
  3. inference.py +7 -0
  4. model.py +39 -0
  5. model.safetensors +1 -1
  6. training_config.yaml +139 -24
README.md CHANGED
@@ -7,7 +7,7 @@ tags:
7
  - rex
8
  ---
9
 
10
- # REX1 Step 29000
11
 
12
  REX is a recursive decoder-only Transformer language model. This repository uses custom
13
  Transformers code, so load it with `trust_remote_code=True`.
@@ -21,12 +21,12 @@ tokenizer = AutoTokenizer.from_pretrained(".")
21
 
22
  ## Checkpoint
23
 
24
- Exported from `runs/rex-300m/ckpt_step29000.pt`.
25
 
26
  ## Training Notes
27
 
28
  - Tokenizer: `gpt2`
29
  - Context length: `1024`
30
- - Training output dir: `runs/rex-300m`
31
 
32
  This is a base language model checkpoint and is not instruction-aligned unless noted.
 
7
  - rex
8
  ---
9
 
10
+ # REX1 300M mixed-2 step 710000
11
 
12
  REX is a recursive decoder-only Transformer language model. This repository uses custom
13
  Transformers code, so load it with `trust_remote_code=True`.
 
21
 
22
  ## Checkpoint
23
 
24
+ Exported from `runs/rex-300m-mixed-2/ckpt_step710000.pt`.
25
 
26
  ## Training Notes
27
 
28
  - Tokenizer: `gpt2`
29
  - Context length: `1024`
30
+ - Training output dir: `runs/rex-300m-mixed-2`
31
 
32
  This is a base language model checkpoint and is not instruction-aligned unless noted.
export_metadata.json CHANGED
@@ -1,4 +1,4 @@
1
  {
2
- "checkpoint": "runs/rex-300m/ckpt_step29000.pt",
3
- "step": 29000
4
  }
 
1
  {
2
+ "checkpoint": "runs/rex-300m-mixed-2/ckpt_step710000.pt",
3
+ "step": 710000
4
  }
inference.py CHANGED
@@ -45,6 +45,12 @@ def build_parser() -> argparse.ArgumentParser:
45
  parser.add_argument("--max-new-tokens", type=int, default=100, help="Number of tokens to generate")
46
  parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature; 0 means greedy")
47
  parser.add_argument("--top-k", type=int, default=50, help="Limit sampling to top-k tokens; <=0 disables")
 
 
 
 
 
 
48
  return parser
49
 
50
 
@@ -73,6 +79,7 @@ def main() -> None:
73
  max_new_tokens=args.max_new_tokens,
74
  temperature=args.temperature,
75
  top_k=top_k,
 
76
  )
77
 
78
  print(tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True))
 
45
  parser.add_argument("--max-new-tokens", type=int, default=100, help="Number of tokens to generate")
46
  parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature; 0 means greedy")
47
  parser.add_argument("--top-k", type=int, default=50, help="Limit sampling to top-k tokens; <=0 disables")
48
+ parser.add_argument(
49
+ "--no-repeat-ngram-size",
50
+ type=int,
51
+ default=0,
52
+ help="Prevent repeated n-grams of this size; 0 disables",
53
+ )
54
  return parser
55
 
56
 
 
79
  max_new_tokens=args.max_new_tokens,
80
  temperature=args.temperature,
81
  top_k=top_k,
82
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
83
  )
84
 
85
  print(tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True))
model.py CHANGED
@@ -211,11 +211,15 @@ class RexForCausalLM(nn.Module):
211
  max_new_tokens: int,
212
  temperature: float = 1.0,
213
  top_k: int | None = None,
 
214
  ) -> torch.Tensor:
215
  self.eval()
 
 
216
  for _ in range(max_new_tokens):
217
  context = input_ids[:, -self.cfg.max_seq_len :]
218
  logits = self(context)["logits"][:, -1, :]
 
219
  if temperature < 0:
220
  raise ValueError("temperature must be >= 0")
221
  if temperature == 0:
@@ -231,6 +235,41 @@ class RexForCausalLM(nn.Module):
231
  input_ids = torch.cat([input_ids, next_token], dim=1)
232
  return input_ids
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  def parameter_count(self, trainable_only: bool = False) -> int:
235
  params = self.parameters()
236
  if trainable_only:
 
211
  max_new_tokens: int,
212
  temperature: float = 1.0,
213
  top_k: int | None = None,
214
+ no_repeat_ngram_size: int = 0,
215
  ) -> torch.Tensor:
216
  self.eval()
217
+ if no_repeat_ngram_size < 0:
218
+ raise ValueError("no_repeat_ngram_size must be >= 0")
219
  for _ in range(max_new_tokens):
220
  context = input_ids[:, -self.cfg.max_seq_len :]
221
  logits = self(context)["logits"][:, -1, :]
222
+ logits = self._apply_no_repeat_ngram(logits, input_ids, no_repeat_ngram_size)
223
  if temperature < 0:
224
  raise ValueError("temperature must be >= 0")
225
  if temperature == 0:
 
235
  input_ids = torch.cat([input_ids, next_token], dim=1)
236
  return input_ids
237
 
238
+ @staticmethod
239
+ def _apply_no_repeat_ngram(
240
+ logits: torch.Tensor,
241
+ input_ids: torch.Tensor,
242
+ no_repeat_ngram_size: int,
243
+ ) -> torch.Tensor:
244
+ if no_repeat_ngram_size <= 0:
245
+ return logits
246
+
247
+ logits = logits.clone()
248
+ for batch_idx in range(input_ids.size(0)):
249
+ banned_tokens = RexForCausalLM._get_banned_ngram_tokens(
250
+ input_ids[batch_idx].tolist(),
251
+ no_repeat_ngram_size,
252
+ )
253
+ if banned_tokens:
254
+ logits[batch_idx, banned_tokens] = float("-inf")
255
+ return logits
256
+
257
+ @staticmethod
258
+ def _get_banned_ngram_tokens(tokens: list[int], ngram_size: int) -> list[int]:
259
+ if ngram_size == 1:
260
+ return list(set(tokens))
261
+ if len(tokens) < ngram_size - 1:
262
+ return []
263
+
264
+ prefix_to_next: dict[tuple[int, ...], set[int]] = {}
265
+ for i in range(len(tokens) - ngram_size + 1):
266
+ ngram = tokens[i : i + ngram_size]
267
+ prefix = tuple(ngram[:-1])
268
+ prefix_to_next.setdefault(prefix, set()).add(ngram[-1])
269
+
270
+ current_prefix = tuple(tokens[-(ngram_size - 1) :])
271
+ return list(prefix_to_next.get(current_prefix, set()))
272
+
273
  def parameter_count(self, trainable_only: bool = False) -> int:
274
  params = self.parameters()
275
  if trainable_only:
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:da14810666664369aa1f2981b658c9b39e6b31a25b3b8c10085a940e37d52cf6
3
  size 1196009344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5ed9638191454c97255326a3f5aed401e13c60f55f142286c247306d9698770
3
  size 1196009344
training_config.yaml CHANGED
@@ -15,32 +15,145 @@ data:
15
  tokenizer_name: gpt2
16
  block_size: 1024
17
  stride: 1024
18
- train_bin: data/train.bin
19
- val_bin: data/val.bin
20
  num_workers: 2
21
  download:
22
- dataset_name: HuggingFaceFW/fineweb-edu
23
- dataset_config: sample-10BT
24
- text_column: text
25
- train_split: train
26
- val_split: null
27
- split_strategy: head
28
- val_fraction: 0.005
29
- streaming: true
30
- seed: 1337
31
- max_train_docs: 50000
32
- max_val_docs: 10000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  train:
34
  seed: 1337
35
  device: auto
36
  dtype: bfloat16
37
- out_dir: runs/rex-300m
38
  batch_size: 8
39
  gradient_accumulation_steps: 1
40
- epochs: 20
41
  max_steps: null
42
- learning_rate: 0.0003
43
- min_lr: 3.0e-05
44
  warmup_steps: 1000
45
  weight_decay: 0.1
46
  betas:
@@ -49,22 +162,24 @@ train:
49
  eps: 1.0e-08
50
  grad_clip: 1.0
51
  compile: true
52
- resume: null
53
  log_every: 10
54
- eval_every: 500
55
- eval_batches: 50
56
- save_every: 1000
57
  wandb:
58
  enabled: true
59
  project: rex
60
  entity: null
61
- name: rex1
62
  group: pretrain
63
  tags:
64
  - recursive-transformer
65
  - 300m
66
- - fineweb-edu
67
- notes: null
 
 
68
  mode: online
69
  watch: false
70
  watch_log: gradients
 
15
  tokenizer_name: gpt2
16
  block_size: 1024
17
  stride: 1024
18
+ train_bin: data/mixed-2/train.bin
19
+ val_bin: data/mixed-2/val.bin
20
  num_workers: 2
21
  download:
22
+ sources:
23
+ - name: fineweb_edu
24
+ dataset_name: HuggingFaceFW/fineweb-edu
25
+ dataset_config: sample-10BT
26
+ text_column: text
27
+ train_split: train
28
+ split_strategy: head
29
+ streaming: true
30
+ max_train_docs: 400000
31
+ max_val_docs: 10000
32
+ - name: cosmopedia_web
33
+ dataset_name: HuggingFaceTB/cosmopedia
34
+ dataset_config: web_samples_v2
35
+ text_column: text
36
+ train_split: train
37
+ split_strategy: head
38
+ streaming: true
39
+ max_train_docs: 50000
40
+ max_val_docs: 5000
41
+ - name: cosmopedia_khanacademy
42
+ dataset_name: HuggingFaceTB/cosmopedia
43
+ dataset_config: khanacademy
44
+ text_column: text
45
+ train_split: train
46
+ split_strategy: head
47
+ streaming: true
48
+ max_train_docs: 50000
49
+ max_val_docs: 5000
50
+ - name: cosmopedia_openstax
51
+ dataset_name: HuggingFaceTB/cosmopedia
52
+ dataset_config: openstax
53
+ text_column: text
54
+ train_split: train
55
+ split_strategy: head
56
+ streaming: true
57
+ max_train_docs: 50000
58
+ max_val_docs: 5000
59
+ - name: cosmopedia_auto_math
60
+ dataset_name: HuggingFaceTB/cosmopedia
61
+ dataset_config: auto_math_text
62
+ text_column: text
63
+ train_split: train
64
+ split_strategy: head
65
+ streaming: true
66
+ max_train_docs: 50000
67
+ max_val_docs: 5000
68
+ - name: cosmopedia_stanford
69
+ dataset_name: HuggingFaceTB/cosmopedia
70
+ dataset_config: stanford
71
+ text_column: text
72
+ train_split: train
73
+ split_strategy: head
74
+ streaming: true
75
+ max_train_docs: 50000
76
+ max_val_docs: 5000
77
+ - name: cosmopedia_wikihow
78
+ dataset_name: HuggingFaceTB/cosmopedia
79
+ dataset_config: wikihow
80
+ text_column: text
81
+ train_split: train
82
+ split_strategy: head
83
+ streaming: true
84
+ max_train_docs: 40000
85
+ max_val_docs: 4000
86
+ - name: wikipedia_en
87
+ dataset_name: wikimedia/wikipedia
88
+ dataset_config: 20231101.en
89
+ text_column: text
90
+ train_split: train
91
+ split_strategy: head
92
+ streaming: true
93
+ max_train_docs: 80000
94
+ max_val_docs: 5000
95
+ - name: open_web_math
96
+ dataset_name: open-web-math/open-web-math
97
+ text_column: text
98
+ train_split: train
99
+ split_strategy: head
100
+ streaming: true
101
+ max_train_docs: 75000
102
+ max_val_docs: 5000
103
+ - name: codeparrot_clean
104
+ dataset_name: codeparrot/codeparrot-clean
105
+ text_column: content
106
+ train_split: train
107
+ split_strategy: head
108
+ streaming: true
109
+ max_train_docs: 34000
110
+ max_val_docs: 2500
111
+ - name: tinystories
112
+ dataset_name: roneneldan/TinyStories
113
+ text_column: text
114
+ train_split: train
115
+ val_split: validation
116
+ split_strategy: head
117
+ streaming: true
118
+ max_train_docs: 200000
119
+ max_val_docs: 5000
120
+ - name: wikitext103
121
+ dataset_name: wikitext
122
+ dataset_config: wikitext-103-raw-v1
123
+ text_column: text
124
+ train_split: train
125
+ val_split: validation
126
+ split_strategy: head
127
+ streaming: false
128
+ max_train_docs: 100000
129
+ max_val_docs: 5000
130
+ - name: arxiv_abstracts
131
+ dataset_name: nick007x/arxiv-papers
132
+ text_column:
133
+ - title
134
+ - subjects
135
+ - abstract
136
+ text_template: 'Title: {title}
137
+
138
+ Subjects: {subjects}
139
+
140
+ Abstract: {abstract}'
141
+ train_split: train
142
+ split_strategy: head
143
+ streaming: true
144
+ max_train_docs: 30000
145
+ max_val_docs: 5000
146
  train:
147
  seed: 1337
148
  device: auto
149
  dtype: bfloat16
150
+ out_dir: runs/rex-300m-mixed-2
151
  batch_size: 8
152
  gradient_accumulation_steps: 1
153
+ epochs: 10
154
  max_steps: null
155
+ learning_rate: 5.0e-05
156
+ min_lr: 5.0e-06
157
  warmup_steps: 1000
158
  weight_decay: 0.1
159
  betas:
 
162
  eps: 1.0e-08
163
  grad_clip: 1.0
164
  compile: true
165
+ resume: runs/rex-300m-mixed-continue/ckpt_step690000.pt
166
  log_every: 10
167
+ eval_every: 5000
168
+ eval_batches: 100
169
+ save_every: 10000
170
  wandb:
171
  enabled: true
172
  project: rex
173
  entity: null
174
+ name: rex1-mixed-2
175
  group: pretrain
176
  tags:
177
  - recursive-transformer
178
  - 300m
179
+ - mixed-2
180
+ - benchmark-mix
181
+ notes: "v2 corpus \u2014 more FineWeb-Edu + Wikipedia, less code/math. Continue\
182
+ \ from mixed-continue step 690k."
183
  mode: online
184
  watch: false
185
  watch_log: gradients