MhaWay committed on
Commit 23af9e0 · verified · 1 Parent(s): cc29f95

Create scripts/train_veronica.py

Files changed (1)
  1. scripts/train_veronica.py +633 -0
scripts/train_veronica.py ADDED
@@ -0,0 +1,633 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Pretrain Veronica-Polymorphic from scratch (clean mixture: FinePDFs / DCLM / FineWeb-Edu).
5
+
6
+ Basic example:
7
+ python veronica-polymorphic/scripts/train_veronica.py \
8
+ --config veronica-polymorphic/configs/veronica-pretrain-12L.json \
9
+ --dataset_paths data/mix_optimal_50_30_20_2048 \
10
+ --output_dir veronica-polymorphic/runs/veronica-pretrain-vMix-2048 \
11
+ --per_device_train_batch_size 4 \
12
+ --gradient_accumulation_steps 4 \
13
+ --learning_rate 2e-4 \
14
+ --label_smoothing 0.01 \
15
+ --rep_alpha 0.0 \
16
+ --max_steps 60000 \
17
+ --max_seq_len 2048
18
+
19
+ You can use different datasets (e.g., 512 / 1024 / 2048) in separate runs for length curriculum.
20
+ """
21
+
22
+ import os
23
+ import re
24
+ import glob
25
+ import json
26
+ import math
27
+ import argparse
28
+ import random
29
+ from dataclasses import dataclass
30
+ from typing import Dict, List, Optional
31
+
32
+ import torch
33
+ import torch.nn.functional as F
34
+ from datasets import load_from_disk
35
+ from transformers import (
36
+ AutoTokenizer,
37
+ Trainer,
38
+ TrainingArguments,
39
+ TrainerCallback,
40
+ CONFIG_MAPPING,
41
+ MODEL_FOR_CAUSAL_LM_MAPPING,
42
+ LogitsProcessorList,
43
+ NoRepeatNGramLogitsProcessor,
44
+ RepetitionPenaltyLogitsProcessor,
45
+ )
46
+
47
+ # --- Veronica bindings ---
48
+ from veronica.configuration_veronica import VeronicaConfig
49
+ from veronica.modeling_veronica import VeronicaForCausalLM
50
+ from veronica.modeling_components import Fp32LayerNorm
51
+
52
+ CONFIG_MAPPING.register("veronica", VeronicaConfig)
53
+ MODEL_FOR_CAUSAL_LM_MAPPING.register(VeronicaConfig, VeronicaForCausalLM)
54
+
55
+ # Disable CUDA Graphs (HF Trainer and torch.compile can occasionally conflict)
56
+ os.environ.setdefault("TORCH_COMPILE_USE_CUDAGRAPHS", "0")
57
+ os.environ.setdefault("TORCHINDUCTOR_DISABLE_CUDAGRAPHS", "1")
58
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
59
+
60
+
61
+ # ===========================
62
+ # Utility
63
+ # ===========================
64
+
65
+ def find_latest_checkpoint(run_dir: str) -> Optional[str]:
66
+ ckpts = glob.glob(os.path.join(run_dir, "checkpoint-*"))
67
+ if not ckpts:
68
+ return None
69
+ ckpts.sort(key=lambda p: int(re.search(r"checkpoint-(\d+)", p).group(1)))
70
+ return ckpts[-1]
71
+
72
+
73
+ def build_tokenizer(candidates: List[str], save_dir: str) -> AutoTokenizer:
74
+ """
75
+ Try to load an existing tokenizer from the provided paths;
76
+ otherwise fall back to gpt2 and add basic special tokens.
77
+ """
78
+ tok = None
79
+ for p in candidates:
80
+ if os.path.exists(p):
81
+ try:
82
+ tok = AutoTokenizer.from_pretrained(p, use_fast=True)
83
+ print(f"[tokenizer] loaded from {p}")
84
+ break
85
+ except Exception:
86
+ pass
87
+ if tok is None:
88
+ print("[tokenizer] fallback: gpt2")
89
+ tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
90
+
91
+ specials: Dict[str, str] = {}
92
+ if tok.eos_token is None:
93
+ specials["eos_token"] = "<|eos|>"
94
+ if tok.pad_token is None:
95
+ specials["pad_token"] = "<|pad|>"
96
+ if tok.bos_token is None:
97
+ specials["bos_token"] = "<|bos|>"
98
+
99
+ if specials:
100
+ tok.add_special_tokens(specials)
101
+
102
+ tok.save_pretrained(save_dir)
103
+ tok = AutoTokenizer.from_pretrained(save_dir, use_fast=True)
104
+ base_vocab = tok.vocab_size
105
+ effective_vocab = len(tok)
106
+ print(
107
+ f"[tokenizer] base_vocab={base_vocab} added={effective_vocab - base_vocab} "
108
+ f"effective_vocab={effective_vocab} eos={tok.eos_token_id} "
109
+ f"pad={tok.pad_token_id} bos={tok.bos_token_id}"
110
+ )
111
+ return tok
112
+
113
+
114
+ def load_cfg_with_vocab(cfg_path: str, tok: AutoTokenizer) -> VeronicaConfig:
115
+ """
116
+ Load the config and adapt it to the tokenizer vocabulary.
117
+ Model is designed as UN-TIED (lm_head != wte).
118
+ """
119
+ with open(cfg_path, "r", encoding="utf-8") as f:
120
+ d = json.load(f)
121
+ cfg = VeronicaConfig(**d)
122
+ cfg.model_type = "veronica"
123
+ cfg.vocab_size = int(len(tok))
124
+ # untied model: no tie_word_embeddings
125
+ return cfg
126
+
127
+
128
+ def init_model_from_config(cfg: VeronicaConfig, tok: AutoTokenizer) -> VeronicaForCausalLM:
129
+ model = VeronicaForCausalLM(cfg)
130
+ use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
131
+ dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)
132
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
133
+ model.to(dtype=dtype, device=device)
134
+
135
+ effective_vocab = len(tok)
136
+ emb = model.get_input_embeddings().weight
137
+ head = model.lm_head.weight
138
+
139
+ # Align embedding/head to the effective vocab
140
+ if emb.shape[0] != effective_vocab or head.shape[0] != effective_vocab:
141
+ old_vocab = emb.shape[0]
142
+ print(f"[model] resize_token_embeddings: {old_vocab} -> {effective_vocab}")
143
+ model.resize_token_embeddings(effective_vocab)
144
+ with torch.no_grad():
145
+ new_emb = model.get_input_embeddings().weight
146
+ new_head = model.lm_head.weight
147
+ mean_emb = new_emb[:old_vocab].mean(dim=0, keepdim=True)
148
+ mean_head = new_head[:old_vocab].mean(dim=0, keepdim=True)
149
+ if effective_vocab > old_vocab:
150
+ new_emb[old_vocab:] = mean_emb
151
+ new_head[old_vocab:] = mean_head
152
+
153
+ # Keep LayerNorm params in float32 (after global cast)
154
+ for m in model.modules():
155
+ if isinstance(m, Fp32LayerNorm):
156
+ m.ln.to(dtype=torch.float32)
157
+
158
+ model.config.use_cache = False
159
+ n_params = sum(p.numel() for p in model.parameters())
160
+ print(f"[model] params={n_params:,} vocab={effective_vocab}")
161
+ return model
162
+
163
+
164
+ def load_mix_dataset(path: str):
165
+ """
166
+ Load a packed dataset (train/validation) from disk.
167
+ Expected HuggingFace formats: a DatasetDict with 'train' and 'validation',
168
+ or a single Dataset that gets split 99/1.
169
+ """
170
+ ds = load_from_disk(path)
171
+ if isinstance(ds, dict) and "train" in ds and "validation" in ds:
172
+ return ds["train"], ds["validation"]
173
+ split = ds.train_test_split(test_size=0.01, seed=42)
174
+ return split["train"], split["test"]
175
+
176
+
177
+ # ===========================
178
+ # Collator
179
+ # ===========================
180
+
181
+ @dataclass
182
+ class CausalCollator:
183
+ tokenizer: AutoTokenizer
184
+ mask_runs: bool = False
185
+ run_len: int = 4
186
+ max_seq_len: Optional[int] = None # target length (e.g., 512/1024/2048)
187
+
188
+ def _mask_degenerate_runs(self, labels: torch.Tensor):
189
+ """
190
+ Mask degenerate runs (e.g., '____', '....') with length >= run_len.
191
+ Mostly legacy; can be left off with a clean dataset.
192
+ """
193
+ try:
194
+ id_us = self.tokenizer.encode("_", add_special_tokens=False)[0]
195
+ id_dot = self.tokenizer.encode(".", add_special_tokens=False)[0]
196
+ except Exception:
197
+ return
198
+ B, T = labels.size()
199
+ for b in range(B):
200
+ cnt_u = cnt_d = 0
201
+ for t in range(T):
202
+ tok = int(labels[b, t].item())
203
+ if tok == id_us:
204
+ cnt_u += 1
205
+ cnt_d = 0
206
+ elif tok == id_dot:
207
+ cnt_d += 1
208
+ cnt_u = 0
209
+ else:
210
+ cnt_u = cnt_d = 0
211
+ if cnt_u >= self.run_len or cnt_d >= self.run_len:
212
+ labels[b, t] = -100
213
+
214
+ def _crop(self, ids: torch.Tensor) -> torch.Tensor:
215
+ """
216
+ If max_seq_len is set and the sequence is longer,
217
+ crop a random window of length max_seq_len.
218
+ """
219
+ if self.max_seq_len is None:
220
+ return ids
221
+ L = ids.size(0)
222
+ if L <= self.max_seq_len:
223
+ return ids
224
+ start = random.randint(0, L - self.max_seq_len)
225
+ end = start + self.max_seq_len
226
+ return ids[start:end]
227
+
228
+ def __call__(self, features):
229
+ ids_list = []
230
+ for f in features:
231
+ ids = torch.tensor(f["input_ids"], dtype=torch.long)
232
+ ids = self._crop(ids)
233
+ ids_list.append(ids)
234
+
235
+ pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
236
+ ids = torch.nn.utils.rnn.pad_sequence(ids_list, batch_first=True, padding_value=pad_id)
237
+ attn = torch.where(ids == pad_id, 0, 1)
238
+
239
+ labels = ids.clone()
240
+ labels[labels == pad_id] = -100
241
+ if self.mask_runs:
242
+ self._mask_degenerate_runs(labels)
243
+
244
+ return {"input_ids": ids, "attention_mask": attn, "labels": labels}
245
+
246
+
247
+ # ===========================
248
+ # Callback Router + Smoke eval
249
+ # ===========================
250
+
251
+ SMOKE_PROMPTS = [
252
+ "The world we live in today is",
253
+ "Understanding complex ideas requires",
254
+ "Human intelligence differs from artificial intelligence because",
255
+ "A good system design is based on",
256
+ "In the middle of every difficulty lies",
257
+ "Once upon a time, there was a scientist who",
258
+ ]
259
+
260
+
261
+ class RouterAndSmokeCallback(TrainerCallback):
262
+ def __init__(self, tok: AutoTokenizer):
263
+ self.tok = tok
264
+
265
+ def on_log(self, args, state, control, **kwargs):
266
+ model = kwargs.get("model", None)
267
+ if model is None:
268
+ return
269
+ try:
270
+ if hasattr(model, "router_alpha_mean") and model.router_alpha_mean is not None:
271
+ alpha = model.router_alpha_mean.detach().float().cpu()
272
+ p = alpha / alpha.sum()
273
+ ent = -(p * (p.clamp_min(1e-9)).log()).sum()
274
+ ent_norm = float(ent / math.log(len(p)))
275
+ print(f"[router] alpha={alpha.tolist()} entropy_norm={ent_norm:.4f}")
276
+ except Exception:
277
+ pass
278
+
279
+ def on_evaluate(self, args, state, control, **kwargs):
280
+ model = kwargs.get("model", None)
281
+ if model is None:
282
+ return
283
+ model.eval()
284
+ dev = next(model.parameters()).device
285
+
286
+ prompt = random.choice(SMOKE_PROMPTS)
287
+ ids = self.tok(prompt, return_tensors="pt").to(dev)
288
+
289
+ processors = LogitsProcessorList([
290
+ NoRepeatNGramLogitsProcessor(3),
291
+ RepetitionPenaltyLogitsProcessor(1.1),
292
+ ])
293
+
294
+ with torch.no_grad():
295
+ out = model.generate(
296
+ **ids,
297
+ max_new_tokens=64,
298
+ do_sample=False,
299
+ logits_processor=processors,
300
+ eos_token_id=self.tok.eos_token_id,
301
+ pad_token_id=(self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id),
302
+ use_cache=True,
303
+ )
304
+ txt = self.tok.decode(out[0], skip_special_tokens=True)
305
+ completion = txt[len(prompt):].strip() if txt.startswith(prompt) else txt
306
+ print(f"\n[SMOKE] {prompt} → {completion}\n")
307
+ model.train()
308
+
309
+
310
+ # ===========================
311
+ # Callback schedule router_tau / aux_weight
312
+ # ===========================
313
+
314
+ class RouterScheduleCallback(TrainerCallback):
315
+ """
316
+ Linearly schedule router_tau and router_aux_weight between start and end of training.
317
+ """
318
+
319
+ def __init__(
320
+ self,
321
+ tau_start: float,
322
+ tau_end: float,
323
+ aux_start: float,
324
+ aux_end: float,
325
+ total_steps: int,
326
+ tau_freeze_steps: int = 0,
327
+ force_prob: float = 0.0,
328
+ force_warmup_steps: int = 0,
329
+ ):
330
+ self.tau_start = float(tau_start)
331
+ self.tau_end = float(tau_end)
332
+ self.aux_start = float(aux_start)
333
+ self.aux_end = float(aux_end)
334
+ self.total_steps = max(int(total_steps), 1)
335
+ self.tau_freeze_steps = max(int(tau_freeze_steps), 0)
336
+ self.force_prob = float(force_prob)
337
+ self.force_warmup_steps = max(int(force_warmup_steps), 0)
338
+
339
+ def _interp(self, start: float, end: float, step: int, span: int) -> float:
340
+ t = min(max(step, 0), span)
341
+ alpha = t / float(max(span, 1))
342
+ return (1.0 - alpha) * start + alpha * end
343
+
344
+ def on_step_begin(self, args, state, control, **kwargs):
345
+ model = kwargs.get("model", None)
346
+ if model is None:
347
+ return
348
+ step = state.global_step
349
+ # Tau: keep frozen for tau_freeze_steps, then interpolate over the remaining span
350
+ if step < self.tau_freeze_steps:
351
+ new_tau = self.tau_start
352
+ else:
353
+ rem_step = step - self.tau_freeze_steps
354
+ rem_span = max(self.total_steps - self.tau_freeze_steps, 1)
355
+ new_tau = self._interp(self.tau_start, self.tau_end, rem_step, rem_span)
356
+
357
+ # Aux always interpolates across total training steps
358
+ new_aux = self._interp(self.aux_start, self.aux_end, step, self.total_steps)
359
+
360
+ # update global config
361
+ if hasattr(model, "config"):
362
+ model.config.router_tau = new_tau
363
+ model.config.router_aux_weight = new_aux
364
+
365
+ # update all block.mlp (PolymorphicMLP must use router_tau in forward)
366
+ for block in getattr(model, "blocks", []):
367
+ if hasattr(block, "mlp"):
368
+ # default: no forcing unless scheduled below
369
+ block.mlp.router_tau = new_tau
370
+ block.mlp.force_func = -1
371
+
372
+ # During early warmup, occasionally force a single branch so all get gradients
373
+ if step < self.force_warmup_steps and self.force_prob > 0.0:
374
+ if random.random() < self.force_prob:
375
+ for block in getattr(model, "blocks", []):
376
+ if hasattr(block, "mlp") and hasattr(block.mlp, "num_funcs"):
377
+ k = block.mlp.num_funcs
378
+ block.mlp.force_func = random.randint(0, max(k - 1, 0))
379
+
380
+ if step % 1000 == 0:
381
+ print(
382
+ f"[router-sched] step={step} tau={new_tau:.4f} aux_w={new_aux:.5f} "
383
+ f"freeze<= {self.tau_freeze_steps} force_p={self.force_prob:.3f} warmup<= {self.force_warmup_steps}"
384
+ )
385
+
386
+
387
+ # ===========================
388
+ # Custom Trainer with rep_loss
389
+ # ===========================
390
+
391
+ class VeronicaTrainer(Trainer):
392
+ def __init__(self, *args, label_smoothing: float = 0.0, rep_alpha: float = 0.0, **kwargs):
393
+ super().__init__(*args, **kwargs)
394
+ self.label_smoothing = float(label_smoothing)
395
+ self.rep_alpha = float(rep_alpha)
396
+
397
+ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
398
+ labels = inputs.get("labels")
399
+ if labels is None:
400
+ raise ValueError("compute_loss called without labels")
401
+ model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
402
+
403
+ outputs = model(**model_inputs)
404
+ logits = outputs.logits # [B, T, V]
405
+
406
+ ignore_index = -100
407
+ # SHIFT: predict x_{t+1}
408
+ shift_logits = logits[:, :-1, :].contiguous()
409
+ shift_labels = labels[:, 1:].contiguous()
410
+
411
+ valid_mask = (shift_labels != ignore_index)
412
+ safe_labels = shift_labels.clone()
413
+ safe_labels[~valid_mask] = 0
414
+
415
+ log_probs = F.log_softmax(shift_logits, dim=-1) # [B, T-1, V]
416
+ nll_full = -log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
417
+ nll_loss = nll_full[valid_mask].mean()
418
+
419
+ if self.label_smoothing > 0.0:
420
+ smooth_full = -log_probs.mean(dim=-1)
421
+ smooth_loss = smooth_full[valid_mask].mean()
422
+ ce_loss = (1.0 - self.label_smoothing) * nll_loss + self.label_smoothing * smooth_loss
423
+ else:
424
+ ce_loss = nll_loss
425
+
426
+ total_loss = ce_loss
427
+
428
+ # rep_loss on x_{t+1} when x_{t+1} == x_t
429
+ if self.rep_alpha > 0.0:
430
+ labels_prev = labels[:, :-1] # x_t
431
+ labels_next = shift_labels # x_{t+1}
432
+ valid_prev = (labels_prev != ignore_index)
433
+ same_mask = valid_prev & valid_mask & (labels_prev == labels_next)
434
+ if same_mask.any():
435
+ rep_logp = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
436
+ rep_p = rep_logp[same_mask].exp()
437
+ total_loss = total_loss + self.rep_alpha * rep_p.mean()
438
+
439
+ # router aux_loss: SUBTRACT to MAXIMIZE entropy (prevent collapse)
440
+ aux_loss = getattr(model, "_last_router_aux", None)
441
+ if aux_loss is not None and hasattr(model, "config"):
442
+ aux_w = float(getattr(model.config, "router_aux_weight", 0.0))
443
+ if aux_w > 0:
444
+ if not torch.is_tensor(aux_loss):
445
+ aux_loss = torch.as_tensor(aux_loss, device=logits.device, dtype=logits.dtype)
446
+ # Subtract aux (entropy) so that minimizing loss => maximize entropy => soft router
447
+ total_loss = total_loss - aux_w * aux_loss.clamp_min(0.0)
448
+
449
+ return (total_loss, outputs) if return_outputs else total_loss
450
+
451
+
452
+ # ===========================
453
+ # Main
454
+ # ===========================
455
+
456
+ def main():
457
+ parser = argparse.ArgumentParser()
458
+ parser.add_argument("--config", type=str, required=True)
459
+ parser.add_argument("--dataset_paths", type=str, required=True)
460
+ parser.add_argument("--output_dir", type=str, required=True, default="veronica-polymorphic/runs/veronica-pretrain")
461
+
462
+ parser.add_argument(
463
+ "--tokenizer_candidates",
464
+ type=str,
465
+ nargs="*",
466
+ default=["veronica-polymorphic/tokenizer", "gpt2"],
467
+ )
468
+ parser.add_argument(
469
+ "--tokenizer_out",
470
+ type=str,
471
+ default="veronica-polymorphic/tokenizer_vmix",
472
+ )
473
+
474
+ parser.add_argument("--per_device_train_batch_size", type=int, default=4)
475
+ parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
476
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
477
+ parser.add_argument("--max_steps", type=int, default=60000)
478
+ parser.add_argument("--learning_rate", type=float, default=2e-4)
479
+ parser.add_argument("--warmup_ratio", type=float, default=0.02)
480
+ parser.add_argument("--weight_decay", type=float, default=0.1)
481
+ parser.add_argument("--eval_steps", type=int, default=1000)
482
+ parser.add_argument("--save_steps", type=int, default=1000)
483
+ parser.add_argument("--logging_steps", type=int, default=100)
484
+ parser.add_argument("--label_smoothing", type=float, default=0.01)
485
+ parser.add_argument("--rep_alpha", type=float, default=0.0)
486
+ parser.add_argument("--mask_degenerate_runs", action="store_true")
487
+ parser.add_argument("--seed", type=int, default=42)
488
+
489
+ parser.add_argument(
490
+ "--resume_from",
491
+ type=str,
492
+ default=None,
493
+ help="Checkpoint to resume from (e.g., .../checkpoint-22000)",
494
+ )
495
+
496
+ parser.add_argument(
497
+ "--max_seq_len",
498
+ type=int,
499
+ default=None,
500
+ help="Maximum window length (e.g., 512, 1024, 2048). If None, uses the full dataset sequence.",
501
+ )
502
+
503
+ # Schedule router
504
+ parser.add_argument("--router_tau_start", type=float, default=1.6)
505
+ parser.add_argument("--router_tau_end", type=float, default=1.1)
506
+ parser.add_argument("--router_aux_start", type=float, default=0.005)
507
+ parser.add_argument("--router_aux_end", type=float, default=0.012)
508
+ parser.add_argument("--router_tau_freeze_steps", type=int, default=4000,
509
+ help="Keep tau constant for the first N steps to avoid early specialization.")
510
+ parser.add_argument("--router_force_prob", type=float, default=0.05,
511
+ help="Per-step probability to force a single branch during warmup.")
512
+ parser.add_argument("--router_force_warmup_steps", type=int, default=3000,
513
+ help="Apply random branch forcing only within these initial steps.")
514
+
515
+ args = parser.parse_args()
516
+
517
+ # Tokenizer
518
+ tok = build_tokenizer(args.tokenizer_candidates, args.tokenizer_out)
519
+
520
+ # Config & Model
521
+ cfg = load_cfg_with_vocab(args.config, tok)
522
+ cfg.router_tau = args.router_tau_start
523
+ cfg.router_aux_weight = args.router_aux_start
524
+
525
+ model = init_model_from_config(cfg, tok)
526
+
527
+ # Diagnostics: verify model forward loss
528
+ model.eval()
529
+ with torch.no_grad():
530
+ dummy = torch.randint(0, model.config.vocab_size, (1, 32), device=model.device)
531
+ out = model(input_ids=dummy, labels=dummy)
532
+ loss_model = out.loss.item()
533
+
534
+ logits = out.logits # [1, 32, V]
535
+ shift_logits = logits[:, :-1, :].contiguous()
536
+ shift_labels = dummy[:, 1:].contiguous()
537
+ loss_manual = F.cross_entropy(
538
+ shift_logits.view(-1, shift_logits.size(-1)),
539
+ shift_labels.view(-1)
540
+ ).item()
541
+
542
+ print(f"[diag] loss_model_forward={loss_model:.4f} loss_manual_shift={loss_manual:.4f}")
543
+ model.train()
544
+
545
+ # Dataset
546
+ train_ds, val_ds = load_mix_dataset(args.dataset_paths)
547
+ collator = CausalCollator(
548
+ tokenizer=tok,
549
+ mask_runs=args.mask_degenerate_runs,
550
+ max_seq_len=args.max_seq_len,
551
+ )
552
+
553
+ # Resume
554
+ resume_ckpt = args.resume_from or find_latest_checkpoint(args.output_dir)
555
+ if resume_ckpt:
556
+ print(f"🟢 Resuming from: {resume_ckpt}")
557
+ else:
558
+ print("⚪ No checkpoint: training from scratch.")
559
+
560
+ use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
561
+
562
+ train_args = TrainingArguments(
563
+ output_dir=args.output_dir,
564
+ run_name=os.path.basename(args.output_dir.rstrip("/")),
565
+ num_train_epochs=1_000, # effectively ignored; training length is driven by max_steps
566
+ max_steps=args.max_steps,
567
+ per_device_train_batch_size=args.per_device_train_batch_size,
568
+ per_device_eval_batch_size=args.per_device_eval_batch_size,
569
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
570
+ learning_rate=args.learning_rate,
571
+ warmup_ratio=args.warmup_ratio,
572
+ weight_decay=args.weight_decay,
573
+ lr_scheduler_type="cosine",
574
+ logging_steps=args.logging_steps,
575
+ eval_steps=args.eval_steps,
576
+ save_steps=args.save_steps,
577
+ eval_strategy="steps", # ✅
578
+ save_total_limit=5,
579
+ bf16=use_bf16,
580
+ fp16=(torch.cuda.is_available() and not use_bf16),
581
+ gradient_checkpointing=True,
582
+ report_to=["tensorboard"],
583
+ dataloader_num_workers=2,
584
+ seed=args.seed,
585
+ label_smoothing_factor=0.0, # smoothing is handled in the custom compute_loss
586
+ max_grad_norm=1.0,
587
+ save_safetensors=False,
588
+ )
589
+
590
+ callbacks: List[TrainerCallback] = [
591
+ RouterAndSmokeCallback(tok),
592
+ RouterScheduleCallback(
593
+ tau_start=args.router_tau_start,
594
+ tau_end=args.router_tau_end,
595
+ aux_start=args.router_aux_start,
596
+ aux_end=args.router_aux_end,
597
+ total_steps=args.max_steps,
598
+ tau_freeze_steps=args.router_tau_freeze_steps,
599
+ force_prob=args.router_force_prob,
600
+ force_warmup_steps=args.router_force_warmup_steps,
601
+ ),
602
+ ]
603
+
604
+ trainer = VeronicaTrainer(
605
+ model=model,
606
+ args=train_args,
607
+ train_dataset=train_ds,
608
+ eval_dataset=val_ds,
609
+ tokenizer=tok, # instead of processing_class
610
+ data_collator=collator,
611
+ callbacks=callbacks,
612
+ label_smoothing=args.label_smoothing,
613
+ rep_alpha=args.rep_alpha,
614
+ )
615
+
616
+ # Sanity check: vocab/emb/head
617
+ effective_vocab = len(tok)
618
+ emb = model.get_input_embeddings().weight
619
+ head = model.lm_head.weight
620
+ assert emb.shape[0] == effective_vocab == head.shape[0], "Mismatch vocab/emb/lm_head"
621
+
622
+ # Train
623
+ trainer.train(resume_from_checkpoint=resume_ckpt)
624
+ trainer.save_state()
625
+ trainer.save_model(args.output_dir)
626
+ tok.save_pretrained(args.output_dir)
627
+ with open(os.path.join(args.output_dir, "config.final.json"), "w", encoding="utf-8") as f:
628
+ json.dump(model.config.to_dict(), f, indent=2)
629
+ print("✅ Pretraining completed/saved.")
630
+
631
+
632
+ if __name__ == "__main__":
633
+ main()