narySt committed · Commit 95afc71 · verified · 1 Parent(s): 8a3033b

Add files using upload-large-folder tool

.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  train_hnet_with_docstring_18_04/wandb/run-20260417_085757-sa79g3yl/run-sa79g3yl.wandb filter=lfs diff=lfs merge=lfs -text
  wandb/run-20260418_121916-2mk39j3k/run-2mk39j3k.wandb filter=lfs diff=lfs merge=lfs -text
+ pythia1b_v5_04_21/wandb/run-20260421_202839-8ing6xdi/run-8ing6xdi.wandb filter=lfs diff=lfs merge=lfs -text
pythia1b_v5_04_21/checkpoints/checkpoint_latest.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b58e3452e25fa6a1b168abc2eeb93bfd178bd0d36dedbb37170f173dd7a163d
+ size 6070947158
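
Each "ADDED" file in this commit is stored as a Git LFS pointer rather than the artifact itself: the repository keeps only the spec version, the SHA-256 object id, and the byte size, while the multi-gigabyte blobs live in LFS storage. As a minimal sketch (the local blob path is hypothetical), a downloaded checkpoint can be checked against its pointer like this:

import hashlib
from pathlib import Path

def verify_lfs_pointer(pointer_path: str, blob_path: str) -> bool:
    """Compare a downloaded blob against the oid/size recorded in an LFS pointer."""
    fields = dict(
        line.split(" ", 1)
        for line in Path(pointer_path).read_text().splitlines()
        if " " in line
    )
    expected_oid = fields["oid"].removeprefix("sha256:")
    if Path(blob_path).stat().st_size != int(fields["size"]):
        return False
    sha = hashlib.sha256()
    with open(blob_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)  # hash in 1 MiB chunks to bound memory use
    return sha.hexdigest() == expected_oid

# Hypothetical usage: pointer text from the repo vs. the resolved blob on disk.
verify_lfs_pointer("checkpoints/checkpoint_latest.pt", "/tmp/checkpoint_latest.pt")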
pythia1b_v5_04_21/checkpoints/checkpoint_step_10591.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b82966ae12c1e00b13902b6d602e024a5131f8f51b2de7fcfdd95b689050729
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_12000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a35b05f715bb652fa11b55de09cf91de2bb1cfd86ebd9c6dedb708608a0d412c
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_15000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:44f0611f7af451e017f865eeaa0281b81ed3a02071e89003d9b5b6e95a1b19d5
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_18000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6cfdf9707733ebf03847f06688369930cc98a11842515a343355432fd5f1b521
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_21000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4ac79d33bc5ef5a51922f227e04f969c5b69356dc854730897d98bb10d9e731
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_21182.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:505dcf26777c122f00bd156aed99f953998a4ec32972fbd3d091ebbe26ee0be4
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_24000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:75521e0fbe7d16afd7a96599a0cbb6ed91ab8f827cf4b8ebe194e89959ec16a1
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_27000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:85a288ed13b41ef7c4d6b185eae644005f8e4f0dbcb930a8ad620668ffcd9c37
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_3000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c372dfd8cc37b45a6747878c0f1ca532c2775264c3814cc5e77edc85b6cf1553
+ size 6070949586
pythia1b_v5_04_21/checkpoints/checkpoint_step_30000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5514bd8b466d5dfbe060bd1030688b3e8631a3a7cc74341f797fcb81bb07202
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_31773.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23a2f1b814172d824cc94fe4eb671efb85eddc0d701dc770bbd2510a862dc9f2
+ size 6070950374
pythia1b_v5_04_21/checkpoints/checkpoint_step_6000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99b4d33f41190374d0a57e12b5715a87ca53f4693512fe0a8d518ff852ccc7e7
+ size 6070949586
pythia1b_v5_04_21/checkpoints/checkpoint_step_9000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f35f51cf37687952cc2753d63e1d0ccfd173b4f43a345440f98613c07c39ac5
+ size 6070949586
pythia1b_v5_04_21/model_best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:777b8524d357a844e33c810453f9bca93a43ee31f69a72619ba2b24f61487d48
+ size 2023640386
pythia1b_v5_04_21/model_final.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:738b4bf2c5af646ac8c184cc78e3a2359d1a6689362cbca0b53ab0cd7d2fea53
+ size 2023640586
pythia1b_v5_04_21/wandb/run-20260421_202839-8ing6xdi/files/code/code_completion_exp/train_pythia/train.py ADDED
@@ -0,0 +1,598 @@
+ """
+ Training pipeline for Pythia (a decoder-only transformer) on the code completion task.
+
+ Configuration via Hydra + OmegaConf, logging to Trackio.
+ DDP support via Accelerate for multi-GPU training.
+
+ Usage:
+     # Basic run (single GPU)
+     python train.py
+
+     # Multi-GPU with Accelerate
+     accelerate launch train.py
+
+     # Multi-GPU with an explicit number of GPUs
+     accelerate launch --num_processes=4 train.py
+
+     # Override parameters via the CLI
+     python train.py training.lr=1e-4 training.epochs=5
+
+     # Select a different model config
+     python train.py model=pythia_160m
+
+     # Multirun (sweep)
+     python train.py --multirun training.lr=1e-4,3e-4,1e-3
+
+     # Disable logging
+     python train.py tracking.enabled=false
+ """
+
+ import os
+ import math
+ import time
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from datasets import load_from_disk
+
+ import hydra
+ from hydra.core.hydra_config import HydraConfig
+ from omegaconf import DictConfig, OmegaConf
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     AutoConfig,
+     PreTrainedTokenizerBase,
+ )
+ from accelerate import Accelerator
+ from accelerate.utils import set_seed as accelerate_set_seed
+
+ # Ensure the repo root is on sys.path (needed when running from a subdirectory)
+ import sys
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
+
+ # Shared training library
+ from training_lib.utils import AverageMeter, log_message
+ from training_lib.checkpointing import save_checkpoint, load_checkpoint
+ from training_lib.schedulers import get_lr_scheduler
+ from training_lib.tracking import init_tracking, log_metrics, finish_tracking
+ from training_lib.validation import run_validation
+
+
+ # ============================================================================
+ # DATA
+ # ============================================================================
+
+
+ class CodeCompletionCollator:
+     """Collate function for batching code completion examples."""
+
+     def __init__(
+         self,
+         tokenizer: PreTrainedTokenizerBase,
+         max_context_len: int = 1024,
+         max_target_len: int = 256,
+     ):
+         self.tokenizer = tokenizer
+         self.max_context_len = max_context_len
+         self.max_target_len = max_target_len
+         self.pad_token_id = tokenizer.pad_token_id
+
+     def __call__(self, batch: list[dict]) -> dict:
+         contexts = [item["context"] for item in batch]
+         targets = [item["target"] for item in batch]
+
+         encoded_contexts = self.tokenizer(
+             contexts,
+             add_special_tokens=True,
+             truncation=True,
+             max_length=self.max_context_len,
+             return_tensors=None,
+         )
+         encoded_targets = self.tokenizer(
+             targets,
+             add_special_tokens=False,
+             truncation=True,
+             max_length=self.max_target_len,
+             return_tensors=None,
+         )
+
+         input_ids_list = []
+         context_lengths = []
+
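+         # Each training example becomes [context tokens][target tokens][EOS];
+         # context_lengths records where the completion starts so the loss can
+         # later be masked to the target span only.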
+         for ctx_ids, tgt_ids in zip(
+             encoded_contexts["input_ids"], encoded_targets["input_ids"]
+         ):
+             tgt_ids = tgt_ids + [self.tokenizer.eos_token_id]
+             context_lengths.append(len(ctx_ids))
+             input_ids_list.append(ctx_ids + tgt_ids)
+
+         max_len = max(len(ids) for ids in input_ids_list)
+
+         padded_input_ids = []
+         attention_mask = []
+
+         for ids in input_ids_list:
+             padding_len = max_len - len(ids)
+             padded_input_ids.append(ids + [self.pad_token_id] * padding_len)
+             attention_mask.append([1] * len(ids) + [0] * padding_len)
+
+         return {
+             "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
+             "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+             "context_lengths": torch.tensor(context_lengths, dtype=torch.long),
+         }
+
+
+ def create_dataloaders(
+     cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
+ ) -> dict[str, DataLoader]:
+     """Create the train and validation DataLoaders."""
+     dataset_dict = load_from_disk(cfg.data.path)
+
+     collator = CodeCompletionCollator(
+         tokenizer=tokenizer,
+         max_context_len=cfg.data.max_context_len,
+         max_target_len=cfg.data.max_target_len,
+     )
+
+     dataloaders = {}
+
+     if "train" in dataset_dict:
+         dataloaders["train"] = DataLoader(
+             dataset_dict["train"],
+             batch_size=cfg.training.batch_size,
+             shuffle=True,
+             collate_fn=collator,
+             num_workers=cfg.data.num_workers,
+             pin_memory=cfg.data.pin_memory,
+         )
+
+     if "validation" in dataset_dict:
+         eval_batch_size = cfg.training.get("eval_batch_size", cfg.training.batch_size)
+         dataloaders["validation"] = DataLoader(
+             dataset_dict["validation"],
+             batch_size=eval_batch_size,
+             shuffle=False,
+             collate_fn=collator,
+             num_workers=cfg.data.num_workers,
+             pin_memory=cfg.data.pin_memory,
+         )
+
+     return dataloaders
+
+
+
+
+ # ============================================================================
+ # LOSS FUNCTIONS
+ # ============================================================================
+
+
+ def compute_loss(
+     logits: torch.Tensor,
+     input_ids: torch.Tensor,
+     context_lengths: torch.Tensor,
+     attention_mask: torch.Tensor,
+ ) -> dict:
+     """Compute the loss for an autoregressive model."""
+     batch_size, seq_len, vocab_size = logits.shape
+
+     shift_logits = logits[:, :-1, :].contiguous()
+     shift_labels = input_ids[:, 1:].contiguous()
+     shift_mask = attention_mask[:, 1:].contiguous()
+
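+     # After the one-token shift, label position j predicts input token j + 1,
+     # so the first completion token of example i sits at shifted index
+     # ctx_len - 1; everything before it (the context) is masked out of the loss.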
+     target_mask = torch.zeros_like(shift_labels, dtype=torch.bool)
+     for i in range(batch_size):
+         ctx_len = context_lengths[i].item()
+         target_mask[i, ctx_len - 1 :] = True
+
+     final_mask = target_mask & shift_mask.bool()
+
+     if final_mask.sum() > 0:
+         loss = F.cross_entropy(
+             shift_logits[final_mask], shift_labels[final_mask], reduction="mean"
+         )
+     else:
+         loss = torch.tensor(0.0, device=logits.device)
+
+     return {"loss": loss}
+
+
+ def _pythia_forward_loss(
+     model: nn.Module,
+     batch: dict,
+     cfg: DictConfig,
+     accelerator: Accelerator,
+ ) -> dict:
+     """Forward + loss for a plain HF causal LM (attention_mask= kwarg, .logits)."""
+     input_ids = batch["input_ids"]
+     attention_mask = batch["attention_mask"]
+     context_lengths = batch["context_lengths"]
+     output = model(input_ids, attention_mask=attention_mask)
+     return compute_loss(output.logits, input_ids, context_lengths, attention_mask)
+
+
+ # ============================================================================
+ # PARAMETER GROUPING
+ # ============================================================================
+
+
+ def group_params(model: nn.Module, weight_decay: float) -> list[dict]:
+     """Group parameters into weight-decay and no-decay sets for the optimizer."""
+     decay_params = []
+     no_decay_params = []
+
+     for name, param in model.named_parameters():
+         if not param.requires_grad:
+             continue
+
+         if "bias" in name or "LayerNorm" in name or "layernorm" in name:
+             no_decay_params.append(param)
+         else:
+             decay_params.append(param)
+
+     return [
+         {"params": decay_params, "weight_decay": weight_decay},
+         {"params": no_decay_params, "weight_decay": 0.0},
+     ]
+
+
+
+
+ # ============================================================================
+ # TRAINING LOOP
+ # ============================================================================
+
+
+ def train_epoch(
+     model: nn.Module,
+     dataloader: DataLoader,
+     optimizer: torch.optim.Optimizer,
+     scheduler,
+     cfg: DictConfig,
+     epoch: int,
+     global_step: int,
+     accelerator: Accelerator,
+     val_dataloader: DataLoader | None = None,
+     best_val_loss: float = float("inf"),
+ ) -> tuple[int, float]:
+     """Run one training epoch. Returns (global_step, best_val_loss)."""
+     model.train()
+
+     loss_meter = AverageMeter()
+
+     optimizer.zero_grad()
+     accumulated_loss = 0.0
+     accumulated_steps = 0
+
+     epoch_start_time = time.time()
+     step_start_time = time.time()
+
+     for batch_idx, batch in enumerate(dataloader):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         context_lengths = batch["context_lengths"]
+
+         with accelerator.autocast():
+             output = model(input_ids, attention_mask=attention_mask)
+             logits = output.logits
+             loss_dict = compute_loss(
+                 logits, input_ids, context_lengths, attention_mask
+             )
+
+         loss = loss_dict["loss"] / cfg.training.gradient_accumulation_steps
+         accelerator.backward(loss)
+
+         accumulated_loss += loss_dict["loss"].item()
+         accumulated_steps += 1
+
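+         # The optimizer only steps once every gradient_accumulation_steps
+         # micro-batches, so global_step counts optimizer updates, not batches.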
+         if accumulated_steps == cfg.training.gradient_accumulation_steps:
+             if cfg.training.max_grad_norm > 0:
+                 accelerator.clip_grad_norm_(
+                     model.parameters(), cfg.training.max_grad_norm
+                 )
+
+             optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad()
+
+             avg_loss = accumulated_loss / cfg.training.gradient_accumulation_steps
+             loss_meter.update(avg_loss)
+
+             global_step += 1
+
+             if global_step % cfg.logging.log_interval == 0:
+                 step_time = time.time() - step_start_time
+                 current_lr = scheduler.get_last_lr()[0]
+
+                 metrics = {
+                     "train/loss": loss_meter.val,
+                     "train/loss_avg": loss_meter.avg,
+                     "train/lr": current_lr,
+                     "train/epoch": epoch,
+                     "train/step_time": step_time / cfg.logging.log_interval,
+                 }
+
+                 log_metrics(metrics, step=global_step)
+
+                 log_message(
+                     f"Epoch {epoch} | Step {global_step} | "
+                     f"Loss: {loss_meter.avg:.4f} | "
+                     f"LR: {current_lr:.2e}",
+                     cfg,
+                     accelerator,
+                 )
+
+                 step_start_time = time.time()
+
+             if (
+                 cfg.logging.save_interval > 0
+                 and global_step % cfg.logging.save_interval == 0
+             ):
+                 save_checkpoint(
+                     model, optimizer, scheduler, global_step, epoch, cfg, accelerator
+                 )
+
+             eval_interval = cfg.logging.get("eval_interval", 0)
+             if (
+                 eval_interval > 0
+                 and val_dataloader is not None
+                 and global_step % eval_interval == 0
+             ):
+                 val_metrics = run_validation(
+                     model=model,
+                     dataloader=val_dataloader,
+                     cfg=cfg,
+                     global_step=global_step,
+                     accelerator=accelerator,
+                     forward_loss_fn=_pythia_forward_loss,
+                 )
+
+                 if val_metrics["val/loss"] < best_val_loss:
+                     best_val_loss = val_metrics["val/loss"]
+                     if accelerator.is_main_process:
+                         best_model_path = Path(cfg.paths.output_dir) / "model_best.pt"
+                         unwrapped_model = accelerator.unwrap_model(model)
+                         torch.save(unwrapped_model.state_dict(), best_model_path)
+                         log_message(
+                             f"New best model saved! Val loss: {best_val_loss:.4f}",
+                             cfg,
+                             accelerator,
+                         )
+
+                     log_metrics(
+                         {
+                             "best/val_loss": best_val_loss,
+                             "best/val_perplexity": val_metrics["val/perplexity"],
+                             "best/step": global_step,
+                         },
+                         step=global_step,
+                     )
+
+                 model.train()
+
+             accumulated_loss = 0.0
+             accumulated_steps = 0
+
+     epoch_time = time.time() - epoch_start_time
+
+     log_message(
+         f"Epoch {epoch} completed in {epoch_time:.2f}s | "
+         f"Loss: {loss_meter.avg:.4f}",
+         cfg,
+         accelerator,
+     )
+
+     log_metrics({
+         "epoch/loss": loss_meter.avg,
+         "epoch/time": epoch_time,
+     })
+
+     return global_step, best_val_loss
+
+
+ # ============================================================================
+ # MAIN
+ # ============================================================================
+
+
+ @hydra.main(version_base=None, config_path="configs", config_name="config")
+ def main(cfg: DictConfig):
+     """Main training entry point, with DDP support via Accelerate."""
+
+     # === Performance: Enable TF32 for faster matmuls on Ampere+ GPUs ===
+     torch.set_float32_matmul_precision('high')
+
+     # === Accelerator Setup ===
+     mixed_precision = "bf16" if cfg.training.use_amp else "no"
+
+     accelerator = Accelerator(
+         mixed_precision=mixed_precision,
+         gradient_accumulation_steps=cfg.training.gradient_accumulation_steps,
+     )
+
+     # === Setup ===
+     accelerate_set_seed(cfg.seed)
+
+     if cfg.paths.output_dir is None:
+         cfg.paths.output_dir = HydraConfig.get().runtime.output_dir
+
+     OmegaConf.resolve(cfg)
+
+     log_message(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'not set')}", cfg, accelerator)
+     log_message(f"Number of processes: {accelerator.num_processes}", cfg, accelerator)
+     log_message(f"Process index: {accelerator.process_index}", cfg, accelerator)
+     log_message(f"Mixed precision: {mixed_precision}", cfg, accelerator)
+
+     log_message("=" * 60, cfg, accelerator)
+     log_message("Pythia Training Pipeline (Hydra + Trackio + Accelerate)", cfg, accelerator)
+     log_message("=" * 60, cfg, accelerator)
+     log_message(f"Config:\n{OmegaConf.to_yaml(cfg)}", cfg, accelerator)
+
+     # === Trackio Init ===
+     init_tracking(cfg, accelerator)
+
+     # === Tokenizer ===
+     log_message("Initializing tokenizer...", cfg, accelerator)
+     tokenizer = AutoTokenizer.from_pretrained(cfg.model.name)
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     # === Model ===
+     log_message("Loading model...", cfg, accelerator)
+
+     # Flash Attention 2
+     torch_dtype = torch.bfloat16 if cfg.training.use_amp else torch.float32
+
+     if cfg.model.checkpoint_path:
+         model = AutoModelForCausalLM.from_pretrained(
+             cfg.model.name,
+             attn_implementation="flash_attention_2",
+             torch_dtype=torch_dtype,
+         )
+         checkpoint = torch.load(cfg.model.checkpoint_path, map_location="cpu")
+         model.load_state_dict(checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint)
+         log_message(f"Loaded checkpoint: {cfg.model.checkpoint_path}", cfg, accelerator)
+     elif cfg.model.from_scratch:
+         config = AutoConfig.from_pretrained(cfg.model.name)
+         config._attn_implementation = "flash_attention_2"
+         model = AutoModelForCausalLM.from_config(config, torch_dtype=torch_dtype)
+         log_message(f"Initialized from scratch: {cfg.model.name}", cfg, accelerator)
+     else:
+         model = AutoModelForCausalLM.from_pretrained(
+             cfg.model.name,
+             attn_implementation="flash_attention_2",
+             torch_dtype=torch_dtype,
+         )
+         log_message(f"Loaded pretrained: {cfg.model.name}", cfg, accelerator)
+
+     model.train()
+
+     # Log model info
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     log_message(f"Total params: {total_params:,}", cfg, accelerator)
+     log_message(f"Trainable params: {trainable_params:,}", cfg, accelerator)
+
+     # === Data ===
+     log_message("Creating dataloaders...", cfg, accelerator)
+     dataloaders = create_dataloaders(cfg, tokenizer)
+
+     train_dataloader = dataloaders["train"]
+     val_dataloader = dataloaders.get("validation", None)
+
+     log_message(f"Train dataset size: {len(train_dataloader.dataset)}", cfg, accelerator)
+     log_message(f"Train batches per epoch (before DDP split): {len(train_dataloader)}", cfg, accelerator)
+
+     if val_dataloader:
+         log_message(f"Validation dataset size: {len(val_dataloader.dataset)}", cfg, accelerator)
+         log_message(f"Validation batches: {len(val_dataloader)}", cfg, accelerator)
+     else:
+         log_message("No validation dataset found", cfg, accelerator)
+
+     # === Optimizer ===
+     log_message("Creating optimizer...", cfg, accelerator)
+     param_groups = group_params(model, cfg.training.weight_decay)
+
+     optimizer = torch.optim.AdamW(
+         param_groups,
+         lr=cfg.training.lr,
+         betas=tuple(cfg.training.betas),
+         eps=cfg.training.eps,
+     )
+
+     # === Scheduler ===
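+     # Note: len(train_dataloader) is measured before accelerator.prepare(),
+     # so it is divided by num_processes to approximate per-process batches.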
+     steps_per_epoch = math.ceil(
+         len(train_dataloader) / accelerator.num_processes
+     )
+     total_steps = (
+         cfg.training.epochs
+         * steps_per_epoch
+         // cfg.training.gradient_accumulation_steps
+     )
+     scheduler = get_lr_scheduler(optimizer, cfg, total_steps)
+
+     log_message(
+         f"Total steps: {total_steps}, Steps per epoch: {steps_per_epoch}",
+         cfg,
+         accelerator,
+     )
+
+     # === Accelerate Prepare ===
+     log_message("Preparing model, optimizer, and dataloaders with Accelerate...", cfg, accelerator)
+
+     if val_dataloader is not None:
+         model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
+             model, optimizer, train_dataloader, val_dataloader, scheduler
+         )
+     else:
+         model, optimizer, train_dataloader, scheduler = accelerator.prepare(
+             model, optimizer, train_dataloader, scheduler
+         )
+
+     log_message(f"Train batches per epoch (after DDP split): {len(train_dataloader)}", cfg, accelerator)
+
+     # === Resume ===
+     global_step = 0
+     start_epoch = 1
+
+     if cfg.training.resume and cfg.training.resume_checkpoint:
+         global_step, start_epoch = load_checkpoint(
+             model, optimizer, scheduler, cfg.training.resume_checkpoint, cfg, accelerator
+         )
+         start_epoch += 1
+
+     # === Training Loop ===
+     log_message("Starting training...", cfg, accelerator)
+
+     best_val_loss = float("inf")
+
+     try:
+         for epoch in range(start_epoch, cfg.training.epochs + 1):
+             log_message(f"\n{'=' * 60}", cfg, accelerator)
+             log_message(f"EPOCH {epoch}/{cfg.training.epochs}", cfg, accelerator)
+             log_message(f"{'=' * 60}", cfg, accelerator)
+
+             global_step, best_val_loss = train_epoch(
+                 model=model,
+                 dataloader=train_dataloader,
+                 optimizer=optimizer,
+                 scheduler=scheduler,
+                 cfg=cfg,
+                 epoch=epoch,
+                 global_step=global_step,
+                 accelerator=accelerator,
+                 val_dataloader=val_dataloader,
+                 best_val_loss=best_val_loss,
+             )
+
+             if cfg.logging.save_every_epoch:
+                 save_checkpoint(
+                     model, optimizer, scheduler, global_step, epoch, cfg, accelerator
+                 )
+
+     except KeyboardInterrupt:
+         log_message("Training interrupted by user", cfg, accelerator)
+         save_checkpoint(model, optimizer, scheduler, global_step, epoch, cfg, accelerator)
+
+     # === Final Save ===
+     log_message("\nTraining completed!", cfg, accelerator)
+
+     if accelerator.is_main_process:
+         final_model_path = Path(cfg.paths.output_dir) / "model_final.pt"
+         unwrapped_model = accelerator.unwrap_model(model)
+         torch.save(unwrapped_model.state_dict(), final_model_path)
+         log_message(f"Final model: {final_model_path}", cfg, accelerator)
+
+     accelerator.wait_for_everyone()
+     finish_tracking()
+
+
+ if __name__ == "__main__":
+     main()
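
The saving code above implies two checkpoint layouts in this commit: model_best.pt and model_final.pt are bare state dicts written with torch.save(unwrapped_model.state_dict(), ...), while the checkpoints/checkpoint_*.pt files come from training_lib.checkpointing.save_checkpoint and, judging by the script's own resume branch, keep the weights under a "model_state_dict" key. A hedged loading sketch (the base model id is an assumption inferred from the folder name):

import torch
from transformers import AutoModelForCausalLM

# Assumption: the architecture matches the folder name (Pythia 1B).
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1b")

state = torch.load("pythia1b_v5_04_21/model_best.pt", map_location="cpu")
# Full training checkpoints wrap the weights; bare exports do not.
if "model_state_dict" in state:
    state = state["model_state_dict"]

model.load_state_dict(state)
model.eval()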
pythia1b_v5_04_21/wandb/run-20260421_202839-8ing6xdi/run-8ing6xdi.wandb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:717527847fea27ac89ab840fa450ede0488f79e543e77544bf788b7b7673ba98
+ size 5275648