Girinath11 committed (verified)
Commit ee2236e · Parent: 0ee651b

Upload 8 files

merges (2).txt ADDED
The diff for this file is too large to render. See raw diff
 
mixture_of_recursion.py ADDED
@@ -0,0 +1,418 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple
import math


@dataclass
class RecursiveLanguageModelConfig:
    vocab_size: int = 50257
    embedding_dim: int = 512
    num_layers: int = 6
    num_attention_heads: int = 8
    max_recursion_steps: int = 5
    max_position_embeddings: int = 512
    hidden_dropout_prob: float = 0.1
    attention_dropout_prob: float = 0.1
    intermediate_size: int = 2048
    layer_norm_eps: float = 1e-5

    pad_token_id: int = 50256
    bos_token_id: int = 50256
    eos_token_id: int = 50256

    simple_recursion_steps: int = 1
    medium_recursion_steps: int = 3
    complex_recursion_steps: int = 5

    confidence_threshold: float = 0.8
    use_adaptive_stopping: bool = True
    initializer_range: float = 0.02


# Model Output class that supports subscripting
class ModelOutput:
    def __init__(self, loss=None, logits=None, complexity_class=None, recursion_steps=None):
        self.loss = loss
        self.logits = logits
        self.complexity_class = complexity_class
        self.recursion_steps = recursion_steps

    def __getitem__(self, key):
        if isinstance(key, str):
            return getattr(self, key)
        elif isinstance(key, int):
            # For subscript access like outputs[0], outputs[1]
            items = [self.loss, self.logits, self.complexity_class, self.recursion_steps]
            return items[key]
        elif isinstance(key, slice):
            items = [self.loss, self.logits, self.complexity_class, self.recursion_steps]
            return items[key]

    def __iter__(self):
        return iter([self.loss, self.logits, self.complexity_class, self.recursion_steps])


class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len = max_seq_len
        self.dim = dim

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


def apply_rotary_pos_emb(q, k, cos, sin):
    def rotate_half(x):
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

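A quick shape check for the rotary embedding above; the sizes are illustrative only (head_dim = 64 matches the default 512-dim embedding split over 8 heads) and are not part of the file itself:

rope = RotaryPositionalEmbedding(dim=64)
cos, sin = rope(seq_len=10, device=torch.device('cpu'))
print(cos.shape, sin.shape)        # torch.Size([10, 64]) each

q = torch.randn(2, 8, 10, 64)      # (batch, heads, seq, head_dim)
k = torch.randn(2, 8, 10, 64)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # cos/sin broadcast over batch and heads
print(q_rot.shape)                 # torch.Size([2, 8, 10, 64])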
class MultiHeadAttention(nn.Module):
    def __init__(self, config: RecursiveLanguageModelConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.embedding_dim // config.num_attention_heads
        self.embed_dim = config.embedding_dim

        assert self.embed_dim % self.num_heads == 0

        self.q_proj = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.k_proj = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.v_proj = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.out_proj = nn.Linear(config.embedding_dim, config.embedding_dim)

        self.dropout = nn.Dropout(config.attention_dropout_prob)
        self.rotary_emb = RotaryPositionalEmbedding(self.head_dim, config.max_position_embeddings)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.shape

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(seq_len, hidden_states.device)
        cos = cos[None, None, :, :].expand(batch_size, self.num_heads, -1, -1)
        sin = sin[None, None, :, :].expand(batch_size, self.num_heads, -1, -1)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output

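A rough shape walk-through of the attention block on dummy tensors, assuming the default config (512-dim embeddings, 8 heads); illustrative only:

cfg = RecursiveLanguageModelConfig()
attn = MultiHeadAttention(cfg)
x = torch.randn(2, 16, cfg.embedding_dim)   # (batch, seq, embed)
out = attn(x)                               # no mask passed, so attention is unrestricted here
print(out.shape)                            # torch.Size([2, 16, 512])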
class FeedForward(nn.Module):
    def __init__(self, config: RecursiveLanguageModelConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.embedding_dim, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.embedding_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config: RecursiveLanguageModelConfig):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
        self.ln2 = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        residual = hidden_states
        hidden_states = self.ln1(hidden_states)
        hidden_states = self.attention(hidden_states, attention_mask)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.ln2(hidden_states)
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class SequenceLevelRouter(nn.Module):
    def __init__(self, config: RecursiveLanguageModelConfig):
        super().__init__()
        self.config = config

        self.pooler = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.pooler_activation = nn.Tanh()

        self.classifier = nn.Sequential(
            nn.Linear(config.embedding_dim, config.embedding_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.embedding_dim // 2, 3)
        )

    def forward(self, hidden_states, attention_mask=None):
        if attention_mask is not None:
            mask_expanded = attention_mask.unsqueeze(-1).float()
            sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
            pooled = sum_hidden / sum_mask
        else:
            pooled = hidden_states.mean(dim=1)

        pooled = self.pooler(pooled)
        pooled = self.pooler_activation(pooled)

        complexity_logits = self.classifier(pooled)
        complexity_class = torch.argmax(complexity_logits, dim=-1)

        recursion_steps = torch.zeros_like(complexity_class)
        recursion_steps[complexity_class == 0] = self.config.simple_recursion_steps
        recursion_steps[complexity_class == 1] = self.config.medium_recursion_steps
        recursion_steps[complexity_class == 2] = self.config.complex_recursion_steps

        return complexity_logits, complexity_class, recursion_steps

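A small sketch of how the router turns pooled hidden states into a complexity class and a per-sequence recursion budget (dummy inputs and untrained weights, so the predicted class is arbitrary):

cfg = RecursiveLanguageModelConfig()
router = SequenceLevelRouter(cfg)
hidden = torch.randn(4, 32, cfg.embedding_dim)    # (batch, seq, embed)
valid = torch.ones(4, 32, dtype=torch.bool)       # 1 = real token, 0 = padding
logits, cls, steps = router(hidden, valid)
print(logits.shape)   # torch.Size([4, 3]) -> simple / medium / complex
print(cls, steps)     # e.g. class 1 maps to medium_recursion_steps = 3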
class RecursionLayer(nn.Module):
    def __init__(self, config: RecursiveLanguageModelConfig):
        super().__init__()
        self.transformer_block = TransformerBlock(config)

    def forward(self, hidden_states, attention_mask=None):
        return self.transformer_block(hidden_states, attention_mask)


class RecursiveLanguageModel(nn.Module):
    def __init__(self, config: RecursiveLanguageModelConfig):
        super().__init__()
        self.config = config

        self.embedding_layer = nn.Embedding(
            config.vocab_size,
            config.embedding_dim,
            padding_idx=config.pad_token_id
        )

        self.base_transformer = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_layers)
        ])

        self.router = SequenceLevelRouter(config)
        self.recursion_layer = RecursionLayer(config)

        self.final_layer_norm = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
        self.language_model_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

        self.tie_weights()
        self._init_weights()

    def tie_weights(self):
        self.language_model_head.weight = self.embedding_layer.weight

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)

    def get_attention_mask(self, input_ids):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Additive causal mask: future (upper-triangular) positions are blocked with -inf
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len, device=device)
        attention_mask[:, :, causal_mask] = float('-inf')

        padding_mask = (input_ids == self.config.pad_token_id)
        valid_mask = ~padding_mask

        if padding_mask.any():
            padding_mask_expanded = padding_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.masked_fill(padding_mask_expanded, float('-inf'))

        return attention_mask, valid_mask

    def forward(self, input_ids, labels=None, attention_mask=None):
        batch_size, seq_len = input_ids.shape

        hidden_states = self.embedding_layer(input_ids)
        attn_mask, padding_mask = self.get_attention_mask(input_ids)

        for layer in self.base_transformer:
            hidden_states = layer(hidden_states, attn_mask)

        complexity_logits, complexity_class, recursion_steps = self.router(
            hidden_states, padding_mask
        )

        if self.training:
            # During training every sequence takes the maximum number of recursion steps
            max_steps = self.config.complex_recursion_steps
            for step in range(max_steps):
                hidden_states = self.recursion_layer(hidden_states, attn_mask)
        else:
            # At inference, a sequence stops updating once its routed step budget is spent
            max_steps_in_batch = int(recursion_steps.max().item())
            for step in range(max_steps_in_batch):
                step_mask = (recursion_steps > step).float().unsqueeze(-1).unsqueeze(-1)
                new_hidden = self.recursion_layer(hidden_states, attn_mask)
                hidden_states = step_mask * new_hidden + (1 - step_mask) * hidden_states

        hidden_states = self.final_layer_norm(hidden_states)
        logits = self.language_model_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            lm_loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1)
            )

            # Length-based pseudo-labels (one bucket per ~170 tokens) supervise the router
            complexity_value = min(max(seq_len // 170, 0), 2)
            pseudo_labels = torch.full(
                (batch_size,),
                complexity_value,
                dtype=torch.long,
                device=input_ids.device
            )

            router_loss_fct = nn.CrossEntropyLoss()
            router_loss = router_loss_fct(complexity_logits, pseudo_labels)

            loss = lm_loss + 0.1 * router_loss

        return ModelOutput(
            loss=loss,
            logits=logits,
            complexity_class=complexity_class,
            recursion_steps=recursion_steps
        )

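A hedged forward-pass example on random token ids (untrained weights). In training mode the loss is lm_loss + 0.1 * router_loss; in eval mode each sequence only receives as many recursion steps as the router assigns it:

cfg = RecursiveLanguageModelConfig()
model = RecursiveLanguageModel(cfg)

ids = torch.randint(0, cfg.vocab_size, (2, 64))
out = model(ids, labels=ids)        # module starts in training mode
print(out.loss)                     # combined LM + router loss
print(out.logits.shape)             # torch.Size([2, 64, 50257])

model.eval()
out = model(ids)
print(out.recursion_steps)          # per-sequence step budget, e.g. tensor([1, 3])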
    def generate(self, input_ids, max_new_tokens=50, temperature=1.0,
                 top_p=0.9, do_sample=True):
        self.eval()
        generated = input_ids

        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.forward(generated)
                logits = outputs.logits

            next_token_logits = logits[:, -1, :] / temperature

            if do_sample:
                # Nucleus (top-p) filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p and mask out the rest
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)

            if next_token.item() == self.config.eos_token_id:
                break

        return generated

    def save_pretrained(self, save_directory):
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))

        config_dict = {
            'vocab_size': self.config.vocab_size,
            'embedding_dim': self.config.embedding_dim,
            'num_layers': self.config.num_layers,
            'num_attention_heads': self.config.num_attention_heads,
            'max_recursion_steps': self.config.max_recursion_steps,
            'max_position_embeddings': self.config.max_position_embeddings,
            'hidden_dropout_prob': self.config.hidden_dropout_prob,
            'attention_dropout_prob': self.config.attention_dropout_prob,
            'intermediate_size': self.config.intermediate_size,
            'layer_norm_eps': self.config.layer_norm_eps,
            'pad_token_id': self.config.pad_token_id,
            'bos_token_id': self.config.bos_token_id,
            'eos_token_id': self.config.eos_token_id,
            'simple_recursion_steps': self.config.simple_recursion_steps,
            'medium_recursion_steps': self.config.medium_recursion_steps,
            'complex_recursion_steps': self.config.complex_recursion_steps,
            'confidence_threshold': self.config.confidence_threshold,
            'use_adaptive_stopping': self.config.use_adaptive_stopping,
            'initializer_range': self.config.initializer_range,
        }

        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
            json.dump(config_dict, f, indent=2)

    @classmethod
    def from_pretrained(cls, load_directory, device='cpu'):
        import os
        import json

        config_path = os.path.join(load_directory, 'config.json')
        with open(config_path, 'r') as f:
            config_dict = json.load(f)

        config = RecursiveLanguageModelConfig(**config_dict)
        model = cls(config)

        weights_path = os.path.join(load_directory, 'pytorch_model.bin')
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)

        model.to(device)
        return model
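A minimal end-to-end sketch tying the pieces above together (untrained weights, so the sampled text is noise; the directory name is arbitrary):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
cfg = RecursiveLanguageModelConfig(vocab_size=len(tok), pad_token_id=tok.eos_token_id,
                                   bos_token_id=tok.eos_token_id, eos_token_id=tok.eos_token_id)
model = RecursiveLanguageModel(cfg)

ids = tok("Recursion lets small models", return_tensors="pt").input_ids
out_ids = model.generate(ids, max_new_tokens=20, temperature=0.8, top_p=0.9)
print(tok.decode(out_ids[0]))

model.save_pretrained("./recursive-lm-demo")   # writes pytorch_model.bin and config.json
restored = RecursiveLanguageModel.from_pretrained("./recursive-lm-demo")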
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a940c0a5f02c276105651421d196092982e55bfb2b8d0c55240de58569a1a197
size 192826915
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
{
  "bos_token": "<|endoftext|>",
  "eos_token": "<|endoftext|>",
  "pad_token": "<|endoftext|>",
  "unk_token": "<|endoftext|>"
}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config (2).json ADDED
@@ -0,0 +1,21 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "50256": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|endoftext|>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|endoftext|>",
  "extra_special_tokens": {},
  "model_max_length": 1024,
  "pad_token": "<|endoftext|>",
  "tokenizer_class": "GPT2Tokenizer",
  "unk_token": "<|endoftext|>"
}
train (2).py ADDED
@@ -0,0 +1,252 @@
import torch
import torch.nn as nn
import math
from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, interleave_datasets
from mixture_of_recursion import RecursiveLanguageModel, RecursiveLanguageModelConfig
import gc

# Configuration
TOTAL_SAMPLES = 50000
BATCH_SIZE = 1
GRAD_ACCUM = 32
EPOCHS = 3
LEARNING_RATE = 3e-4
MAX_LENGTH = 384

print("Starting training with 50K premium samples")
print("-" * 60)

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Pad token ID: {tokenizer.pad_token_id}")

# Load datasets
print("\nLoading datasets...")
print(" FineWeb-Edu (45%)")
fineweb = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.45))

print(" Cosmopedia (30%)")
cosmopedia = load_dataset(
    "HuggingFaceTB/cosmopedia",
    "web_samples_v1",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.30))

print(" OpenWebText (25%)")
openwebtext = load_dataset(
    "openwebtext",
    split="train",
    streaming=True
).shuffle(seed=42).take(int(TOTAL_SAMPLES * 0.25))

# Mix datasets
print("\nMixing datasets...")
train_dataset = interleave_datasets(
    [fineweb, cosmopedia, openwebtext],
    probabilities=[0.45, 0.30, 0.25],
    seed=42
)

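The three corpora are streamed and interleaved by probability (roughly 22,500 / 15,000 / 12,500 documents at the 45/30/25 split of 50,000) rather than concatenated. A hedged way to peek at the mixed stream before tokenizing; field names differ per corpus, which is why the tokenize function below checks for both 'text' and 'content':

for example in train_dataset.take(2):
    print(sorted(example.keys()))   # e.g. ['text', ...]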
# Tokenization function
def tokenize(examples):
    if 'text' in examples:
        texts = examples['text']
    elif 'content' in examples:
        texts = examples['content']
    else:
        texts = list(examples.values())[0]

    return tokenizer(
        texts,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False
    )

# Tokenize datasets
print("Tokenizing...")
tokenized_train = train_dataset.map(
    tokenize,
    batched=True,
    remove_columns=train_dataset.column_names
).filter(lambda x: len(x['input_ids']) >= 128)

# Validation set
val_dataset = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True
).take(1000)

val_tokenized = val_dataset.map(
    tokenize,
    batched=True,
    remove_columns=val_dataset.column_names
).filter(lambda x: len(x['input_ids']) >= 128)

# Build model
print("\nBuilding model...")
config = RecursiveLanguageModelConfig(
    vocab_size=len(tokenizer),
    embedding_dim=512,
    num_layers=6,
    num_attention_heads=8,
    max_recursion_steps=5,
    max_position_embeddings=512,
    intermediate_size=2048,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.pad_token_id,
    simple_recursion_steps=1,
    medium_recursion_steps=3,
    complex_recursion_steps=5,
    use_adaptive_stopping=True,
    hidden_dropout_prob=0.1,
    attention_dropout_prob=0.1
)

model = RecursiveLanguageModel(config)

params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Model parameters: {params:.1f}M")

# Clear cache
torch.cuda.empty_cache()
gc.collect()

# Training setup
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

steps_per_epoch = TOTAL_SAMPLES // (BATCH_SIZE * GRAD_ACCUM)
max_steps = steps_per_epoch * EPOCHS

print(f"\nTraining steps: {max_steps}")
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")

training_args = TrainingArguments(
    output_dir="./checkpoints",
    max_steps=max_steps,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_steps=500,
    fp16=True,
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to="none",
    max_grad_norm=1.0,
    save_safetensors=False,  # Use PyTorch format instead of safetensors
)

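With the constants above, the effective batch is 1 × 32 = 32 sequences, giving 50000 // 32 = 1562 optimizer steps per epoch and 1562 × 3 = 4686 total steps, so the 500 warmup steps cover roughly the first tenth of training.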
# Custom trainer with perplexity
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(**inputs)
        return (outputs.loss, outputs) if return_outputs else outputs.loss

    def evaluation_loop(self, dataloader, description, prediction_loss_only=None,
                        ignore_keys=None, metric_key_prefix="eval"):
        output = super().evaluation_loop(
            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
        )

        if output.metrics.get(f"{metric_key_prefix}_loss") is not None:
            try:
                perplexity = math.exp(output.metrics[f"{metric_key_prefix}_loss"])
                output.metrics[f"{metric_key_prefix}_perplexity"] = perplexity
            except OverflowError:
                output.metrics[f"{metric_key_prefix}_perplexity"] = float("inf")

        return output

    def training_step(self, model, inputs, num_items_in_batch=None):
        loss = super().training_step(model, inputs, num_items_in_batch)

        if self.state.global_step % 50 == 0:
            torch.cuda.empty_cache()

        return loss

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=val_tokenized,
    data_collator=data_collator
)

# Train
print("\nStarting training...")
print("-" * 60)

try:
    trainer.train()

    # Final evaluation
    print("\nFinal evaluation...")
    metrics = trainer.evaluate()

    print("\n" + "="*60)
    print("FINAL RESULTS:")
    print("="*60)
    print(f"Evaluation Loss: {metrics['eval_loss']:.4f}")

    if 'eval_perplexity' in metrics:
        print(f"Perplexity: {metrics['eval_perplexity']:.2f}")
    else:
        try:
            perplexity = math.exp(metrics['eval_loss'])
            print(f"Perplexity: {perplexity:.2f}")
        except OverflowError:
            print("Perplexity: inf (loss too high)")
    print("="*60 + "\n")

    # Save with custom method (handles tied weights properly)
    print("Saving model...")
    model.save_pretrained("./recursive-lm")
    tokenizer.save_pretrained("./recursive-lm")
    print("Model saved successfully!")

except KeyboardInterrupt:
    print("\n\nTraining interrupted by user")
    print("Saving current model state...")
    model.save_pretrained("./recursive-lm-interrupted")
    tokenizer.save_pretrained("./recursive-lm-interrupted")

except Exception as e:
    print(f"\n\nTraining stopped due to: {e}")
    import traceback
    traceback.print_exc()

    # Try to save anyway
    try:
        print("\nAttempting to save model...")
        model.save_pretrained("./recursive-lm-error")
        tokenizer.save_pretrained("./recursive-lm-error")
        print("Model saved!")
    except Exception:
        print("Could not save model")

print("\nTraining complete!")
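Once training finishes, the checkpoint written by the custom save_pretrained can be reloaded for inference; a minimal sketch (paths match the script above, prompt text is arbitrary):

import torch
from transformers import AutoTokenizer
from mixture_of_recursion import RecursiveLanguageModel

tokenizer = AutoTokenizer.from_pretrained("./recursive-lm")
model = RecursiveLanguageModel.from_pretrained("./recursive-lm", device="cpu")
model.eval()

prompt_ids = tokenizer("The mixture-of-recursion model", return_tensors="pt").input_ids
output_ids = model.generate(prompt_ids, max_new_tokens=40, temperature=0.8, top_p=0.9)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))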
vocab (2).json ADDED
The diff for this file is too large to render. See raw diff