summerMC commited on
Commit
ec757bc
·
verified ·
1 Parent(s): c8b6819

Upload folder using huggingface_hub

Browse files
__pycache__/modeling_van_fast.cpython-312.pyc ADDED
Binary file (21.2 kB). View file
 
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "VanFastForCausalLM"
4
+ ],
5
+ "block_size": 1024,
6
+ "bos_token_id": 50256,
7
+ "d_ff": 4096,
8
+ "d_model": 1024,
9
+ "dropout": 0.0,
10
+ "dtype": "float32",
11
+ "eos_token_id": 50256,
12
+ "initializer_range": 0.02,
13
+ "is_decoder": true,
14
+ "model_type": "van_fast_transformer",
15
+ "n_head": 16,
16
+ "n_kv_head": 4,
17
+ "n_layer": 18,
18
+ "pad_token_id": 50256,
19
+ "tie_word_embeddings": false,
20
+ "transformers_version": "5.0.0",
21
+ "use_qk_norm": true,
22
+ "vocab_size": 50257,
23
+ "auto_map": {
24
+ "AutoConfig": "modeling_van_fast.VanFastConfig",
25
+ "AutoModelForCausalLM": "modeling_van_fast.VanFastForCausalLM"
26
+ },
27
+ "torch_dtype": "bfloat16",
28
+ "use_cache": true
29
+ }
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "output_attentions": false,
6
+ "output_hidden_states": false,
7
+ "pad_token_id": 50256,
8
+ "transformers_version": "5.0.0"
9
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aca8a9d4b041994c006a3dcde0de9d0279a46e270cf0111af07fd4eb1da64f40
3
+ size 1506599392
modeling_van_fast.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from transformers import PretrainedConfig, PreTrainedModel
7
+ from transformers.generation import GenerationMixin
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+
10
+
11
+ def safe_tensor(x, clamp=30.0):
12
+ x = torch.nan_to_num(
13
+ x,
14
+ nan=0.0,
15
+ posinf=clamp,
16
+ neginf=-clamp,
17
+ )
18
+ x = torch.clamp(x, min=-clamp, max=clamp)
19
+ return x
20
+
21
+
22
+ class VanFastConfig(PretrainedConfig):
23
+ model_type = "van_fast_transformer"
24
+
25
+ def __init__(
26
+ self,
27
+ vocab_size=50257,
28
+ block_size=1024,
29
+ d_model=1024,
30
+ n_layer=18,
31
+ n_head=16,
32
+ n_kv_head=4,
33
+ d_ff=4096,
34
+ dropout=0.0,
35
+ use_qk_norm=True,
36
+ initializer_range=0.02,
37
+ pad_token_id=None,
38
+ eos_token_id=None,
39
+ bos_token_id=None,
40
+ use_cache=True,
41
+ **kwargs,
42
+ ):
43
+ super().__init__(
44
+ pad_token_id=pad_token_id,
45
+ eos_token_id=eos_token_id,
46
+ bos_token_id=bos_token_id,
47
+ **kwargs,
48
+ )
49
+
50
+ self.vocab_size = vocab_size
51
+ self.block_size = block_size
52
+ self.d_model = d_model
53
+ self.n_layer = n_layer
54
+ self.n_head = n_head
55
+ self.n_kv_head = n_kv_head
56
+ self.d_ff = d_ff
57
+ self.dropout = dropout
58
+ self.use_qk_norm = use_qk_norm
59
+ self.initializer_range = initializer_range
60
+
61
+ self.is_decoder = True
62
+ self.is_encoder_decoder = False
63
+ self.tie_word_embeddings = False
64
+ self.use_cache = use_cache
65
+
66
+
67
+ class HFRMSNorm(nn.Module):
68
+ def __init__(self, dim: int, eps: float = 1e-6):
69
+ super().__init__()
70
+ self.eps = eps
71
+ self.weight = nn.Parameter(torch.ones(dim))
72
+
73
+ def forward(self, x):
74
+ x = safe_tensor(x, clamp=30.0)
75
+
76
+ x_float = x.float()
77
+ var = x_float.pow(2).mean(dim=-1, keepdim=True)
78
+ var = torch.nan_to_num(var, nan=1.0, posinf=1.0, neginf=1.0)
79
+ var = torch.clamp(var, min=0.0, max=1e6)
80
+
81
+ y = x_float * torch.rsqrt(var + self.eps)
82
+ y = y.to(dtype=x.dtype) * self.weight.to(dtype=x.dtype)
83
+ y = safe_tensor(y, clamp=30.0)
84
+
85
+ return y
86
+
87
+
88
+ class HFRotaryEmbedding(nn.Module):
89
+ def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
90
+ super().__init__()
91
+
92
+ inv_freq = 1.0 / (
93
+ base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
94
+ )
95
+
96
+ t = torch.arange(max_seq_len, dtype=torch.float32)
97
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
98
+
99
+ cos = freqs.cos()
100
+ sin = freqs.sin()
101
+
102
+ self.register_buffer("cos_cached", cos[None, None, :, :], persistent=False)
103
+ self.register_buffer("sin_cached", sin[None, None, :, :], persistent=False)
104
+
105
+ def forward(self, x, seq_len: int, offset: int = 0):
106
+ end = offset + seq_len
107
+
108
+ max_len = self.cos_cached.shape[2]
109
+ if end > max_len:
110
+ # block_sizeを超えた場合は最後の範囲に丸める
111
+ offset = max(0, max_len - seq_len)
112
+ end = offset + seq_len
113
+
114
+ cos = self.cos_cached[:, :, offset:end, :].to(device=x.device, dtype=x.dtype)
115
+ sin = self.sin_cached[:, :, offset:end, :].to(device=x.device, dtype=x.dtype)
116
+
117
+ return cos, sin
118
+
119
+
120
+ def hf_apply_rope(q, k, cos, sin):
121
+ q1 = q[..., ::2]
122
+ q2 = q[..., 1::2]
123
+
124
+ k1 = k[..., ::2]
125
+ k2 = k[..., 1::2]
126
+
127
+ q_rot = torch.stack(
128
+ [
129
+ q1 * cos - q2 * sin,
130
+ q1 * sin + q2 * cos,
131
+ ],
132
+ dim=-1,
133
+ ).flatten(-2)
134
+
135
+ k_rot = torch.stack(
136
+ [
137
+ k1 * cos - k2 * sin,
138
+ k1 * sin + k2 * cos,
139
+ ],
140
+ dim=-1,
141
+ ).flatten(-2)
142
+
143
+ q_rot = safe_tensor(q_rot, clamp=10.0)
144
+ k_rot = safe_tensor(k_rot, clamp=10.0)
145
+
146
+ return q_rot, k_rot
147
+
148
+
149
+ class HFGQAAttention(nn.Module):
150
+ def __init__(self, config: VanFastConfig):
151
+ super().__init__()
152
+
153
+ d_model = config.d_model
154
+ n_head = config.n_head
155
+ n_kv_head = config.n_kv_head
156
+
157
+ assert d_model % n_head == 0
158
+ assert n_head % n_kv_head == 0
159
+
160
+ self.d_model = d_model
161
+ self.n_head = n_head
162
+ self.n_kv_head = n_kv_head
163
+ self.head_dim = d_model // n_head
164
+ self.num_groups = n_head // n_kv_head
165
+ self.dropout = config.dropout
166
+ self.block_size = config.block_size
167
+
168
+ assert self.head_dim % 2 == 0
169
+
170
+ self.q_proj = nn.Linear(d_model, n_head * self.head_dim, bias=False)
171
+ self.k_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
172
+ self.v_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
173
+ self.o_proj = nn.Linear(d_model, d_model, bias=False)
174
+
175
+ if config.use_qk_norm:
176
+ self.q_norm = HFRMSNorm(self.head_dim)
177
+ self.k_norm = HFRMSNorm(self.head_dim)
178
+ else:
179
+ self.q_norm = nn.Identity()
180
+ self.k_norm = nn.Identity()
181
+
182
+ self.rope = HFRotaryEmbedding(
183
+ dim=self.head_dim,
184
+ max_seq_len=config.block_size,
185
+ )
186
+
187
+ def forward(
188
+ self,
189
+ x,
190
+ past_key_value=None,
191
+ use_cache=False,
192
+ ):
193
+ x = safe_tensor(x, clamp=30.0)
194
+
195
+ B, T, C = x.shape
196
+
197
+ q = self.q_proj(x)
198
+ k = self.k_proj(x)
199
+ v = self.v_proj(x)
200
+
201
+ q = safe_tensor(q, clamp=30.0)
202
+ k = safe_tensor(k, clamp=30.0)
203
+ v = safe_tensor(v, clamp=30.0)
204
+
205
+ q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
206
+ k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
207
+ v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
208
+
209
+ q = self.q_norm(q)
210
+ k = self.k_norm(k)
211
+
212
+ q = safe_tensor(q, clamp=10.0)
213
+ k = safe_tensor(k, clamp=10.0)
214
+ v = safe_tensor(v, clamp=30.0)
215
+
216
+ past_len = 0
217
+
218
+ if past_key_value is not None:
219
+ past_k, past_v = past_key_value
220
+ past_len = past_k.shape[2]
221
+
222
+ cos, sin = self.rope(q, T, offset=past_len)
223
+ q, k = hf_apply_rope(q, k, cos, sin)
224
+
225
+ if past_key_value is not None:
226
+ past_k, past_v = past_key_value
227
+ k = torch.cat([past_k, k], dim=2)
228
+ v = torch.cat([past_v, v], dim=2)
229
+
230
+ # cache長をblock_size以内に制限
231
+ if k.shape[2] > self.block_size:
232
+ k = k[:, :, -self.block_size:, :].contiguous()
233
+ v = v[:, :, -self.block_size:, :].contiguous()
234
+
235
+ present_key_value = (k, v) if use_cache else None
236
+
237
+ k_attn = k
238
+ v_attn = v
239
+
240
+ if self.num_groups > 1:
241
+ k_attn = k_attn.repeat_interleave(self.num_groups, dim=1)
242
+ v_attn = v_attn.repeat_interleave(self.num_groups, dim=1)
243
+
244
+ # prefill時はcausal、decode時はqueryが最新1tokenなので全cacheへattend可能
245
+ is_causal = past_key_value is None
246
+
247
+ y = F.scaled_dot_product_attention(
248
+ q,
249
+ k_attn,
250
+ v_attn,
251
+ attn_mask=None,
252
+ dropout_p=self.dropout if self.training else 0.0,
253
+ is_causal=is_causal,
254
+ )
255
+
256
+ y = safe_tensor(y, clamp=30.0)
257
+
258
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
259
+ y = self.o_proj(y)
260
+ y = safe_tensor(y, clamp=30.0)
261
+
262
+ return y, present_key_value
263
+
264
+
265
+ class HFSwiGLU(nn.Module):
266
+ def __init__(self, config: VanFastConfig):
267
+ super().__init__()
268
+
269
+ self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False)
270
+ self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False)
271
+ self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False)
272
+
273
+ def forward(self, x):
274
+ x = safe_tensor(x, clamp=30.0)
275
+
276
+ a = self.w1(x)
277
+ b = self.w3(x)
278
+
279
+ a = safe_tensor(a, clamp=30.0)
280
+ b = safe_tensor(b, clamp=30.0)
281
+
282
+ y = F.silu(a) * b
283
+ y = safe_tensor(y, clamp=30.0)
284
+
285
+ y = self.w2(y)
286
+ y = safe_tensor(y, clamp=30.0)
287
+
288
+ return y
289
+
290
+
291
+ class HFDecoderBlock(nn.Module):
292
+ def __init__(self, config: VanFastConfig):
293
+ super().__init__()
294
+
295
+ self.attn_norm = HFRMSNorm(config.d_model)
296
+ self.attn = HFGQAAttention(config)
297
+
298
+ self.ffn_norm = HFRMSNorm(config.d_model)
299
+ self.ffn = HFSwiGLU(config)
300
+
301
+ def forward(
302
+ self,
303
+ x,
304
+ past_key_value=None,
305
+ use_cache=False,
306
+ ):
307
+ x = safe_tensor(x, clamp=30.0)
308
+
309
+ a, present_key_value = self.attn(
310
+ self.attn_norm(x),
311
+ past_key_value=past_key_value,
312
+ use_cache=use_cache,
313
+ )
314
+
315
+ a = safe_tensor(a, clamp=30.0)
316
+ x = safe_tensor(x + a, clamp=30.0)
317
+
318
+ f = self.ffn(self.ffn_norm(x))
319
+ f = safe_tensor(f, clamp=30.0)
320
+ x = safe_tensor(x + f, clamp=30.0)
321
+
322
+ return x, present_key_value
323
+
324
+
325
+ class VanFastForCausalLM(PreTrainedModel, GenerationMixin):
326
+ config_class = VanFastConfig
327
+ base_model_prefix = "van_fast"
328
+ supports_gradient_checkpointing = False
329
+ _supports_cache_class = False
330
+
331
+ def __init__(self, config: VanFastConfig):
332
+ super().__init__(config)
333
+
334
+ self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
335
+ self.drop = nn.Dropout(config.dropout)
336
+
337
+ self.blocks = nn.ModuleList([
338
+ HFDecoderBlock(config)
339
+ for _ in range(config.n_layer)
340
+ ])
341
+
342
+ self.norm = HFRMSNorm(config.d_model)
343
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
344
+
345
+ self.post_init()
346
+
347
+ def _init_weights(self, module):
348
+ std = getattr(self.config, "initializer_range", 0.02)
349
+
350
+ if isinstance(module, nn.Linear):
351
+ nn.init.normal_(module.weight, mean=0.0, std=std)
352
+ if module.bias is not None:
353
+ nn.init.zeros_(module.bias)
354
+
355
+ elif isinstance(module, nn.Embedding):
356
+ nn.init.normal_(module.weight, mean=0.0, std=std)
357
+
358
+ def get_input_embeddings(self):
359
+ return self.token_emb
360
+
361
+ def set_input_embeddings(self, value):
362
+ self.token_emb = value
363
+
364
+ def get_output_embeddings(self):
365
+ return self.lm_head
366
+
367
+ def set_output_embeddings(self, new_embeddings):
368
+ self.lm_head = new_embeddings
369
+
370
+ def _normalize_past(self, past_key_values):
371
+ if past_key_values is None:
372
+ return [None] * len(self.blocks)
373
+
374
+ if isinstance(past_key_values, tuple):
375
+ past_key_values = list(past_key_values)
376
+
377
+ if len(past_key_values) < len(self.blocks):
378
+ past_key_values = past_key_values + [None] * (
379
+ len(self.blocks) - len(past_key_values)
380
+ )
381
+
382
+ return past_key_values
383
+
384
+ def forward(
385
+ self,
386
+ input_ids=None,
387
+ labels=None,
388
+ attention_mask=None,
389
+ past_key_values=None,
390
+ use_cache=None,
391
+ return_dict=True,
392
+ **kwargs,
393
+ ):
394
+ if input_ids is None:
395
+ raise ValueError("input_ids is required")
396
+
397
+ if use_cache is None:
398
+ use_cache = getattr(self.config, "use_cache", True)
399
+
400
+ has_past = past_key_values is not None
401
+
402
+ # cache使用時は新規tokenだけ処理
403
+ if has_past and input_ids.shape[1] > 1:
404
+ input_ids = input_ids[:, -1:]
405
+
406
+ # cacheなしのprefill時だけblock_sizeに丸める
407
+ if not has_past and input_ids.shape[1] > self.config.block_size:
408
+ input_ids = input_ids[:, -self.config.block_size:]
409
+ if labels is not None:
410
+ labels = labels[:, -self.config.block_size:]
411
+
412
+ past_key_values = self._normalize_past(past_key_values)
413
+
414
+ x = self.token_emb(input_ids)
415
+ x = safe_tensor(x, clamp=30.0)
416
+
417
+ x = self.drop(x)
418
+
419
+ presents = [] if use_cache else None
420
+
421
+ for i, block in enumerate(self.blocks):
422
+ layer_past = past_key_values[i]
423
+
424
+ x, present = block(
425
+ x,
426
+ past_key_value=layer_past,
427
+ use_cache=use_cache,
428
+ )
429
+
430
+ if use_cache:
431
+ presents.append(present)
432
+
433
+ x = self.norm(x)
434
+ x = safe_tensor(x, clamp=30.0)
435
+
436
+ logits = self.lm_head(x)
437
+
438
+ logits = logits.float()
439
+ logits = torch.nan_to_num(
440
+ logits,
441
+ nan=0.0,
442
+ posinf=80.0,
443
+ neginf=-80.0,
444
+ )
445
+ logits = torch.clamp(logits, min=-80.0, max=80.0)
446
+
447
+ loss = None
448
+
449
+ if labels is not None:
450
+ shift_logits = logits[:, :-1, :].contiguous()
451
+ shift_labels = labels[:, 1:].contiguous()
452
+
453
+ if shift_logits.numel() > 0:
454
+ loss = F.cross_entropy(
455
+ shift_logits.view(-1, shift_logits.size(-1)),
456
+ shift_labels.view(-1),
457
+ ignore_index=-100,
458
+ )
459
+
460
+ past_out = tuple(presents) if use_cache else None
461
+
462
+ if not return_dict:
463
+ if loss is None:
464
+ return (logits, past_out)
465
+ return (loss, logits, past_out)
466
+
467
+ return CausalLMOutputWithPast(
468
+ loss=loss,
469
+ logits=logits,
470
+ past_key_values=past_out,
471
+ hidden_states=None,
472
+ attentions=None,
473
+ )
474
+
475
+ def prepare_inputs_for_generation(
476
+ self,
477
+ input_ids,
478
+ past_key_values=None,
479
+ attention_mask=None,
480
+ use_cache=True,
481
+ **kwargs,
482
+ ):
483
+ if past_key_values is not None:
484
+ input_ids = input_ids[:, -1:]
485
+ else:
486
+ if input_ids.shape[1] > self.config.block_size:
487
+ input_ids = input_ids[:, -self.config.block_size:]
488
+
489
+ return {
490
+ "input_ids": input_ids,
491
+ "attention_mask": attention_mask,
492
+ "past_key_values": past_key_values,
493
+ "use_cache": use_cache,
494
+ }
495
+
496
+ def _reorder_cache(self, past_key_values, beam_idx):
497
+ if past_key_values is None:
498
+ return None
499
+
500
+ reordered = []
501
+
502
+ for layer_past in past_key_values:
503
+ if layer_past is None:
504
+ reordered.append(None)
505
+ continue
506
+
507
+ k, v = layer_past
508
+ reordered.append(
509
+ (
510
+ k.index_select(0, beam_idx.to(k.device)),
511
+ v.index_select(0, beam_idx.to(v.device)),
512
+ )
513
+ )
514
+
515
+ return tuple(reordered)
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|endoftext|>",
5
+ "eos_token": "<|endoftext|>",
6
+ "errors": "replace",
7
+ "is_local": false,
8
+ "model_max_length": 1024,
9
+ "pad_token": "<|endoftext|>",
10
+ "tokenizer_class": "GPT2Tokenizer",
11
+ "unk_token": "<|endoftext|>"
12
+ }
training_cfg.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "OUT_DIR": "/content/van_fast_transformer",
3
+ "TOKENIZER_NAME": "gpt2",
4
+ "DATASET_NAME": "HuggingFaceFW/fineweb-edu",
5
+ "DATASET_CONFIG": "sample-10BT",
6
+ "DATASET_SPLIT": "train",
7
+ "TEXT_KEY": "text",
8
+ "VOCAB_SIZE": 50257,
9
+ "BLOCK_SIZE": 1024,
10
+ "D_MODEL": 1024,
11
+ "N_LAYER": 18,
12
+ "N_HEAD": 16,
13
+ "N_KV_HEAD": 4,
14
+ "D_FF": 4096,
15
+ "DROPOUT": 0.0,
16
+ "USE_QK_NORM": true,
17
+ "MAX_STEPS": 5000,
18
+ "BATCH_SIZE": 1,
19
+ "GRAD_ACCUM": 4,
20
+ "LR": 0.0003,
21
+ "MIN_LR": 3e-05,
22
+ "WARMUP_STEPS": 300,
23
+ "WEIGHT_DECAY": 0.1,
24
+ "BETA1": 0.9,
25
+ "BETA2": 0.95,
26
+ "MAX_GRAD_NORM": 1.0,
27
+ "EARLY_STOP_LOSS": 0.0001,
28
+ "EARLY_STOP_PATIENCE": 1,
29
+ "EARLY_STOP_SAVE": true,
30
+ "EARLY_STOP_ON_EVAL": false,
31
+ "EARLY_STOP_EVAL_LOSS": 0.0001,
32
+ "EARLY_STOP_EVAL_PATIENCE": 2,
33
+ "LOG_EVERY": 10,
34
+ "EVAL_EVERY": 1000,
35
+ "SAVE_EVERY": 1000,
36
+ "EVAL_BATCHES": 4,
37
+ "GEN_MAX_NEW_TOKENS": 160,
38
+ "GEN_TEMPERATURE": 0.8,
39
+ "GEN_TOP_K": 50,
40
+ "GEN_TOP_P": 0.95,
41
+ "SEED": 42,
42
+ "DTYPE": "bf16",
43
+ "TF32": true,
44
+ "COMPILE": true,
45
+ "GRADIENT_CHECKPOINTING": false,
46
+ "NUM_WORKERS": 2,
47
+ "PIN_MEMORY": true,
48
+ "DEBUG_SMALL": false
49
+ }