ronnengmail commited on
Commit
452f44f
·
verified ·
1 Parent(s): 198c544

Upload training_scripts/train_multilingual_3b.py with huggingface_hub

Browse files
training_scripts/train_multilingual_3b.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Multilingual 3.14B GPT Training (Arabic-Rebalanced Data)
4
+
5
+ Scaled from 1B v2/v3 validated recipe:
6
+ - Architecture: dim=3072, depth=26, heads=24, ~3.14B params
7
+ - Data: training-data-v2 (4.48B tokens, multi-epoch)
8
+ - Schedule: WSD-LINEAR (validated at 1B)
9
+ - LR: 3e-4 (scaled from 1B's 5e-4 via sqrt rule)
10
+ """
11
+
12
+ import os, sys, json, math, time, copy
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch.distributed as dist
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
20
+
21
+ # ============ MODEL CONFIG ============
22
+ VOCAB_SIZE = 32000
23
+ DIM = 3072
24
+ DEPTH = 26
25
+ N_HEADS = 24
26
+ MAX_SEQ_LEN = 2048
27
+ ROPE_THETA = 10000
28
+ DROPOUT = 0.05 # Slightly less than 1B (0.1) — larger models need less regularization
29
+
30
+ # ============ TRAINING CONFIG ============
31
+ TOTAL_STEPS = 20000 # ~10.5B tokens (2.3 epochs of 4.48B)
32
+ WARMUP_STEPS = 600 # 3% warmup
33
+ STABLE_END = 14000 # 70% at peak LR
34
+ MIN_LR_RATIO = 0.03
35
+
36
+ BATCH_PER_GPU = 2 # Per-GPU batch size
37
+ GRAD_ACCUM = 16 # With 8 GPUs: 8*4*8 = 256 seqs = 524K tokens/step
38
+ # Total: 20000 * 524288 = 10.49B tokens
39
+
40
+ ADAMW_LR = 3e-4 # Scaled: 5e-4 * sqrt(502M/3142M) ≈ 2e-4, being slightly aggressive
41
+ ADAMW_BETAS = (0.9, 0.98)
42
+ ADAMW_WD = 0.02
43
+ ADAMW_EPS = 1e-8
44
+
45
+ LABEL_SMOOTHING = 0.06
46
+ GRAD_CLIP = 1.0
47
+
48
+ SWA_START_FRAC = 0.40 # Start at step 8000
49
+ SWA_FREQ = 40 # Every 40 steps (scaled from 1B's 20)
50
+
51
+ EVAL_EVERY = 500 # Eval every 500 steps
52
+ SAVE_EVERY = 2000 # Checkpoint every 2000 steps
53
+ LOG_EVERY = 50
54
+
55
+ DATA_DIR = "/tmp/training-data"
56
+ CKPT_DIR = "/tmp/checkpoints"
57
+ LOG_FILE = "/tmp/training.log"
58
+ EVAL_FILE = "/tmp/eval_results.json"
59
+
60
+ S3_BUCKET = "autoresearch-dashboard-196766918360"
61
+ S3_PREFIX = "multilingual-7b"
62
+ VERSION = "3b-v1"
63
+
64
+ # ============ MODEL ============
65
+ class RMSNorm(nn.Module):
66
+ def __init__(self, dim, eps=1e-6):
67
+ super().__init__()
68
+ self.weight = nn.Parameter(torch.ones(dim))
69
+ self.eps = eps
70
+ def forward(self, x):
71
+ return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps).type_as(x) * self.weight
72
+
73
+ class SwiGLU(nn.Module):
74
+ def __init__(self, dim, hidden_dim):
75
+ super().__init__()
76
+ self.gate = nn.Linear(dim, hidden_dim, bias=False)
77
+ self.up = nn.Linear(dim, hidden_dim, bias=False)
78
+ self.down = nn.Linear(hidden_dim, dim, bias=False)
79
+ def forward(self, x):
80
+ return self.down(F.silu(self.gate(x)) * self.up(x))
81
+
82
+ def apply_rope(x, cos, sin):
83
+ x1, x2 = x[..., ::2], x[..., 1::2]
84
+ return torch.stack((x1*cos - x2*sin, x1*sin + x2*cos), dim=-1).flatten(-2)
85
+
86
+ class Attention(nn.Module):
87
+ def __init__(self, dim, n_heads, dropout=0.0):
88
+ super().__init__()
89
+ self.n_heads = n_heads
90
+ self.head_dim = dim // n_heads
91
+ self.qkv = nn.Linear(dim, 3*dim, bias=False)
92
+ self.proj = nn.Linear(dim, dim, bias=False)
93
+ self.attn_dropout = dropout
94
+ def forward(self, x, cos, sin):
95
+ B, T, C = x.shape
96
+ qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
97
+ q, k, v = qkv[0], qkv[1], qkv[2]
98
+ q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
99
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True,
100
+ dropout_p=self.attn_dropout if self.training else 0.0)
101
+ return self.proj(y.transpose(1, 2).contiguous().view(B, T, C))
102
+
103
+ class Block(nn.Module):
104
+ def __init__(self, dim, n_heads, mlp_dim, dropout=0.0):
105
+ super().__init__()
106
+ self.ln1 = RMSNorm(dim)
107
+ self.attn = Attention(dim, n_heads, dropout)
108
+ self.ln2 = RMSNorm(dim)
109
+ self.mlp = SwiGLU(dim, mlp_dim)
110
+ self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
111
+ def forward(self, x, cos, sin):
112
+ x = x + self.drop(self.attn(self.ln1(x), cos, sin))
113
+ x = x + self.drop(self.mlp(self.ln2(x)))
114
+ return x
115
+
116
+ class GPT(nn.Module):
117
+ def __init__(self, vocab_size=VOCAB_SIZE, dim=DIM, depth=DEPTH, n_heads=N_HEADS,
118
+ max_seq_len=MAX_SEQ_LEN, rope_theta=ROPE_THETA, dropout=DROPOUT):
119
+ super().__init__()
120
+ self.tok_emb = nn.Embedding(vocab_size, dim)
121
+ mlp_dim = ((int(2 * dim * 4 / 3) + 63) // 64) * 64 # = 8192 for dim=3072
122
+ self.blocks = nn.ModuleList([Block(dim, n_heads, mlp_dim, dropout) for _ in range(depth)])
123
+ self.ln_f = RMSNorm(dim)
124
+ self.head = nn.Linear(dim, vocab_size, bias=False)
125
+ self.head.weight = self.tok_emb.weight # weight tying
126
+ hd = dim // n_heads
127
+ freqs = 1.0 / (rope_theta ** (torch.arange(0, hd, 2).float() / hd))
128
+ angles = torch.outer(torch.arange(max_seq_len).float(), freqs)
129
+ self.register_buffer('rope_cos', angles.cos())
130
+ self.register_buffer('rope_sin', angles.sin())
131
+ self.apply(self._init_weights)
132
+
133
+ def _init_weights(self, module):
134
+ if isinstance(module, nn.Linear):
135
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
136
+ if module.bias is not None:
137
+ torch.nn.init.zeros_(module.bias)
138
+ elif isinstance(module, nn.Embedding):
139
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
140
+
141
+ def forward(self, idx):
142
+ B, T = idx.shape
143
+ x = self.tok_emb(idx)
144
+ cos = self.rope_cos[:T][None, None]
145
+ sin = self.rope_sin[:T][None, None]
146
+ for block in self.blocks:
147
+ if self.training:
148
+ x = torch_checkpoint(block, x, cos, sin, use_reentrant=False)
149
+ else:
150
+ x = block(x, cos, sin)
151
+ return self.head(self.ln_f(x))
152
+
153
+ # ============ WSD LINEAR SCHEDULE ============
154
+ def wsd_lr_linear(step, total_steps, warmup_steps, stable_end, min_lr_ratio, base_lr):
155
+ if step < warmup_steps:
156
+ return base_lr * (step + 1) / max(warmup_steps, 1)
157
+ elif step < stable_end:
158
+ return base_lr
159
+ else:
160
+ progress = (step - stable_end) / max(total_steps - stable_end, 1)
161
+ return base_lr * (1.0 - progress * (1.0 - min_lr_ratio))
162
+
163
+ # ============ DATA LOADING ============
164
+ class BinaryDataset:
165
+ def __init__(self, path, seq_len):
166
+ self.data = np.memmap(path, dtype=np.uint16, mode='r')
167
+ self.seq_len = seq_len
168
+ self.n_tokens = len(self.data)
169
+ def get_batch(self, batch_size, device, rng):
170
+ ix = torch.from_numpy(rng.integers(0, self.n_tokens - self.seq_len - 1, size=(batch_size,)))
171
+ x = torch.stack([torch.from_numpy(self.data[i:i+self.seq_len].astype(np.int64)) for i in ix])
172
+ y = torch.stack([torch.from_numpy(self.data[i+1:i+1+self.seq_len].astype(np.int64)) for i in ix])
173
+ return x.to(device), y.to(device)
174
+
175
+ def load_val_data(path, seq_len, max_batches=20, batch_size=8):
176
+ data = np.memmap(path, dtype=np.uint16, mode='r')
177
+ n_tokens = len(data)
178
+ batches = []
179
+ stride = seq_len + 1
180
+ all_starts = list(range(0, n_tokens - stride, stride))
181
+ max_samples = max_batches * batch_size
182
+ if len(all_starts) > max_samples:
183
+ step_size = len(all_starts) // max_samples
184
+ all_starts = all_starts[::step_size][:max_samples]
185
+ for i in range(0, len(all_starts), batch_size):
186
+ batch_starts = all_starts[i:i+batch_size]
187
+ if len(batch_starts) < batch_size:
188
+ break
189
+ x = torch.stack([torch.from_numpy(data[s:s+seq_len].astype(np.int64)) for s in batch_starts])
190
+ y = torch.stack([torch.from_numpy(data[s+1:s+1+seq_len].astype(np.int64)) for s in batch_starts])
191
+ batches.append((x, y))
192
+ return batches
193
+
194
+ @torch.no_grad()
195
+ def evaluate(model, val_batches, device):
196
+ model.eval()
197
+ total_loss = 0.0
198
+ total_tokens = 0
199
+ for x, y in val_batches:
200
+ x, y = x.to(device), y.to(device)
201
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
202
+ logits = model(x)
203
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), reduction='sum')
204
+ total_loss += loss.item()
205
+ total_tokens += y.numel()
206
+ model.train()
207
+ return (total_loss / total_tokens) / math.log(2) if total_tokens > 0 else float('inf')
208
+
209
+ class Logger:
210
+ def __init__(self, log_file, rank):
211
+ self.rank = rank
212
+ if rank == 0:
213
+ self.f = open(log_file, 'w')
214
+ def log(self, msg):
215
+ if self.rank == 0:
216
+ ts = time.strftime('%Y-%m-%d %H:%M:%S')
217
+ line = f"[{ts}] {msg}"
218
+ print(line, flush=True)
219
+ self.f.write(line + '\n')
220
+ self.f.flush()
221
+ def close(self):
222
+ if self.rank == 0:
223
+ self.f.close()
224
+
225
+ class SWAState:
226
+ def __init__(self):
227
+ self.avg_state = None
228
+ self.n_averaged = 0
229
+ def update(self, model):
230
+ state = {k: v.cpu().float().clone() for k, v in model.module.state_dict().items()}
231
+ if self.avg_state is None:
232
+ self.avg_state = state
233
+ self.n_averaged = 1
234
+ else:
235
+ n = self.n_averaged
236
+ for k in self.avg_state:
237
+ self.avg_state[k] = (self.avg_state[k] * n + state[k]) / (n + 1)
238
+ self.n_averaged += 1
239
+
240
+ # ============ MAIN ============
241
+ def main():
242
+ dist.init_process_group('nccl')
243
+ rank = dist.get_rank()
244
+ world_size = dist.get_world_size()
245
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
246
+ device = torch.device(f'cuda:{local_rank}')
247
+ torch.cuda.set_device(device)
248
+
249
+ effective_batch = BATCH_PER_GPU * GRAD_ACCUM * world_size
250
+ tokens_per_step = effective_batch * MAX_SEQ_LEN
251
+
252
+ logger = Logger(LOG_FILE, rank)
253
+ logger.log(f"=== Multilingual 3.14B Training (Arabic-Rebalanced) ===")
254
+ logger.log(f"World size: {world_size}, Batch/GPU: {BATCH_PER_GPU}, Grad accum: {GRAD_ACCUM}")
255
+ logger.log(f"Effective batch: {effective_batch} seqs = {tokens_per_step:,} tokens/step")
256
+ logger.log(f"Total steps: {TOTAL_STEPS} = {TOTAL_STEPS * tokens_per_step:,} tokens")
257
+ logger.log(f"Schedule: WSD-LINEAR | warmup={WARMUP_STEPS} | stable_end={STABLE_END} | total={TOTAL_STEPS}")
258
+ logger.log(f"AdamW LR={ADAMW_LR}, betas={ADAMW_BETAS}, WD={ADAMW_WD}")
259
+ logger.log(f"Label smoothing={LABEL_SMOOTHING}, min_lr={MIN_LR_RATIO}, grad_clip={GRAD_CLIP}")
260
+ logger.log(f"SWA: start={int(TOTAL_STEPS*SWA_START_FRAC)}, freq={SWA_FREQ}")
261
+ logger.log(f"Model: dim={DIM}, depth={DEPTH}, heads={N_HEADS}, dropout={DROPOUT}")
262
+
263
+ os.makedirs(CKPT_DIR, exist_ok=True)
264
+
265
+ # Data
266
+ logger.log("Loading training data...")
267
+ train_ds = BinaryDataset(f"{DATA_DIR}/train.bin", MAX_SEQ_LEN)
268
+ logger.log(f"Train tokens: {train_ds.n_tokens:,}")
269
+
270
+ logger.log("Loading validation data...")
271
+ val_batches = load_val_data(f"{DATA_DIR}/val.bin", MAX_SEQ_LEN)
272
+ val_lang_batches = {}
273
+ for lang in ['en', 'ar', 'he', 'fa']:
274
+ vpath = f"{DATA_DIR}/val_{lang}.bin"
275
+ if os.path.exists(vpath):
276
+ val_lang_batches[lang] = load_val_data(vpath, MAX_SEQ_LEN)
277
+ logger.log(f" val_{lang}: {len(val_lang_batches[lang])} batches")
278
+
279
+ # Model
280
+ logger.log("Creating model...")
281
+ torch.manual_seed(42)
282
+ model = GPT().to(device)
283
+ n_params = sum(p.numel() for p in model.parameters())
284
+ n_params_no_emb = n_params - model.tok_emb.weight.numel()
285
+ logger.log(f"Model params: {n_params:,} (non-embedding: {n_params_no_emb:,})")
286
+ logger.log(f"GPU memory after model: {torch.cuda.memory_allocated(device)/1e9:.1f} GB")
287
+
288
+ model = DDP(model, device_ids=[local_rank])
289
+ try:
290
+ import bitsandbytes as bnb
291
+ optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_WD,
292
+ betas=ADAMW_BETAS, eps=ADAMW_EPS)
293
+ logger.log("Using 8-bit AdamW (bitsandbytes)")
294
+ except ImportError:
295
+ optimizer = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_WD,
296
+ betas=ADAMW_BETAS, eps=ADAMW_EPS)
297
+ logger.log("Using standard AdamW (bitsandbytes not available)")
298
+
299
+ swa = SWAState()
300
+ swa_start_step = int(TOTAL_STEPS * SWA_START_FRAC)
301
+ rng = np.random.default_rng(42 + rank)
302
+ scaler = torch.amp.GradScaler('cuda')
303
+ best_val_bpb = float('inf')
304
+ eval_results = []
305
+ tokens_processed = 0
306
+ start_time = time.time()
307
+
308
+ logger.log("Starting training...")
309
+
310
+ for step in range(1, TOTAL_STEPS + 1):
311
+ model.train()
312
+ lr = wsd_lr_linear(step, TOTAL_STEPS, WARMUP_STEPS, STABLE_END, MIN_LR_RATIO, ADAMW_LR)
313
+ for g in optimizer.param_groups:
314
+ g['lr'] = lr
315
+
316
+ optimizer.zero_grad()
317
+ accum_loss = 0.0
318
+ for micro in range(GRAD_ACCUM):
319
+ x, y = train_ds.get_batch(BATCH_PER_GPU, device, rng)
320
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
321
+ logits = model(x)
322
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1),
323
+ label_smoothing=LABEL_SMOOTHING) / GRAD_ACCUM
324
+ scaler.scale(loss).backward()
325
+ accum_loss += loss.item()
326
+
327
+ scaler.unscale_(optimizer)
328
+ torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
329
+ scaler.step(optimizer)
330
+ scaler.update()
331
+ tokens_processed += tokens_per_step
332
+
333
+ if step >= swa_start_step and step % SWA_FREQ == 0 and rank == 0:
334
+ swa.update(model)
335
+
336
+ if step % LOG_EVERY == 0 and rank == 0:
337
+ elapsed = time.time() - start_time
338
+ tps = tokens_processed / elapsed
339
+ bpb = accum_loss / math.log(2)
340
+ phase = "warmup" if step < WARMUP_STEPS else ("stable" if step < STABLE_END else "decay")
341
+ mem = torch.cuda.max_memory_allocated(device) / 1e9
342
+ logger.log(f"Step {step}/{TOTAL_STEPS} [{phase}] | Loss: {accum_loss:.4f} | "
343
+ f"BPB: {bpb:.4f} | LR: {lr:.6f} | Tokens: {tokens_processed:,} | "
344
+ f"TPS: {tps:,.0f} | SWA: {swa.n_averaged} | Mem: {mem:.1f}GB | {elapsed/60:.1f}min")
345
+
346
+ if step % EVAL_EVERY == 0 or step == TOTAL_STEPS:
347
+ if rank == 0:
348
+ logger.log(f"--- Evaluation at step {step} ---")
349
+ combined_bpb = evaluate(model.module, val_batches, device)
350
+ logger.log(f" Combined val BPB: {combined_bpb:.4f}")
351
+ result = {"step": step, "tokens": tokens_processed, "combined_bpb": combined_bpb}
352
+ for lang, batches in val_lang_batches.items():
353
+ lang_bpb = evaluate(model.module, batches, device)
354
+ result[f"{lang}_bpb"] = lang_bpb
355
+ logger.log(f" {lang} val BPB: {lang_bpb:.4f}")
356
+ eval_results.append(result)
357
+ with open(EVAL_FILE, 'w') as f:
358
+ json.dump(eval_results, f, indent=2)
359
+ if combined_bpb < best_val_bpb:
360
+ best_val_bpb = combined_bpb
361
+ torch.save(model.module.state_dict(), f"{CKPT_DIR}/best_model.pt")
362
+ logger.log(f" New best! BPB: {combined_bpb:.4f}")
363
+ dist.barrier()
364
+
365
+ if step % SAVE_EVERY == 0 and rank == 0:
366
+ ckpt = {
367
+ 'step': step, 'model': model.module.state_dict(),
368
+ 'optimizer': optimizer.state_dict(), 'scaler': scaler.state_dict(),
369
+ 'best_val_bpb': best_val_bpb, 'tokens_processed': tokens_processed,
370
+ 'eval_results': eval_results, 'swa_n': swa.n_averaged,
371
+ }
372
+ torch.save(ckpt, f"{CKPT_DIR}/ckpt_step_{step}.pt")
373
+ logger.log(f"Saved checkpoint at step {step}")
374
+ # Background upload
375
+ os.system(f"aws s3 cp {CKPT_DIR}/best_model.pt s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_best_model.pt --quiet &")
376
+ os.system(f"aws s3 cp {EVAL_FILE} s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_eval_results.json --quiet &")
377
+ os.system(f"aws s3 cp {LOG_FILE} s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_training.log --quiet &")
378
+
379
+ # Finalize
380
+ if rank == 0:
381
+ torch.save(model.module.state_dict(), f"{CKPT_DIR}/final_model.pt")
382
+ logger.log("Saved final model")
383
+
384
+ if swa.avg_state is not None and swa.n_averaged > 0:
385
+ logger.log(f"Evaluating SWA model ({swa.n_averaged} checkpoints)...")
386
+ swa_model = GPT().to(device)
387
+ swa_load = {k: v.float().to(device) for k, v in swa.avg_state.items()}
388
+ swa_model.load_state_dict(swa_load)
389
+ swa_bpb = evaluate(swa_model, val_batches, device)
390
+ logger.log(f"SWA model combined BPB: {swa_bpb:.4f} (vs best raw: {best_val_bpb:.4f})")
391
+ swa_result = {"step": "swa", "combined_bpb": swa_bpb, "n_averaged": swa.n_averaged}
392
+ for lang, batches in val_lang_batches.items():
393
+ lang_bpb = evaluate(swa_model, batches, device)
394
+ swa_result[f"{lang}_bpb"] = lang_bpb
395
+ logger.log(f" SWA {lang} BPB: {lang_bpb:.4f}")
396
+ eval_results.append(swa_result)
397
+ torch.save(swa_load, f"{CKPT_DIR}/swa_model.pt")
398
+ with open(EVAL_FILE, 'w') as f:
399
+ json.dump(eval_results, f, indent=2)
400
+ del swa_model
401
+
402
+ # Final S3 upload
403
+ logger.log("Uploading all artifacts to S3...")
404
+ os.system(f"aws s3 sync {CKPT_DIR}/ s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}/")
405
+ os.system(f"aws s3 cp {LOG_FILE} s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_training.log")
406
+ os.system(f"aws s3 cp {EVAL_FILE} s3://{S3_BUCKET}/{S3_PREFIX}/checkpoints/{VERSION}_eval_results.json")
407
+
408
+ elapsed = time.time() - start_time
409
+ logger.log(f"=== Training complete! Total time: {elapsed/3600:.2f}h ===")
410
+ logger.log(f"Best combined BPB: {best_val_bpb:.4f}")
411
+ logger.log(f"Total tokens: {tokens_processed:,}")
412
+
413
+ logger.close()
414
+ dist.destroy_process_group()
415
+
416
+ if __name__ == '__main__':
417
+ main()