algorythmtechnologies committed
Commit 6ce3b41 · verified · 1 Parent(s): 8174855

Update supernova/train.py

Files changed (1)
  1. supernova/train.py +415 -159
supernova/train.py CHANGED
@@ -1,159 +1,415 @@
- import argparse
- import json
- import math
- import os
- import time
- from typing import Optional
-
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- from transformers import get_cosine_schedule_with_warmup
-
- from .config import ModelConfig
- from .model import SupernovaModel
- from .tokenizer import load_gpt2_tokenizer
- from .data import load_sources_from_yaml, TokenChunkDataset
-
-
- def compute_grad_norm(model: nn.Module) -> float:
-     total = 0.0
-     for p in model.parameters():
-         if p.grad is not None:
-             param_norm = p.grad.data.float().norm(2).item()
-             total += param_norm * param_norm
-     return math.sqrt(total)
-
-
- def train(
-     config_path: str,
-     data_config_path: str,
-     seq_len: int = 1024,
-     batch_size: int = 16,
-     grad_accum: int = 8,
-     lr: float = 3e-4,
-     warmup_steps: int = 2000,
-     max_steps: int = 100_000,
-     save_every: int = 10_000,
-     out_dir: str = "checkpoints",
-     seed: int = 42,
- ):
-     torch.manual_seed(seed)
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     cfg = ModelConfig.from_json_file(config_path)
-     # Assert exact parameter budget from formula
-     cfg.assert_exact_params(expected=25_000_000)
-
-     tok = load_gpt2_tokenizer()
-     assert tok.vocab_size == cfg.vocab_size, (
-         f"Tokenizer vocab size ({tok.vocab_size}) != config ({cfg.vocab_size})"
-     )
-
-     model = SupernovaModel(cfg).to(device)
-
-     # Double-check exact parameter count by instantiating
-     total_params = sum(p.numel() for p in model.parameters())
-     assert total_params == 25_000_000, f"Model has {total_params} params, expected 25,000,000"
-
-     sources = load_sources_from_yaml(data_config_path)
-     ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
-     dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
-
-     optimizer = torch.optim.AdamW(
-         model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
-     )
-
-     # We use a token-based schedule; max_steps is optimizer steps, not micro-steps
-     scheduler = get_cosine_schedule_with_warmup(
-         optimizer,
-         num_warmup_steps=warmup_steps,
-         num_training_steps=max_steps,
-     )
-
-     model.train()
-     os.makedirs(out_dir, exist_ok=True)
-
-     step = 0
-     micro = 0
-     running_loss = 0.0
-     t0 = time.time()
-
-     while step < max_steps:
-         for batch in dl:
-             x, y = batch
-             x = x.to(device)
-             y = y.to(device)
-
-             logits, loss = model(x, y)
-             loss = loss / grad_accum
-             loss.backward()
-
-             micro += 1
-             running_loss += loss.item()
-
-             if micro % grad_accum == 0:
-                 # Optional clip: leave off by default for pure monitoring
-                 # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
-                 optimizer.step()
-                 optimizer.zero_grad(set_to_none=True)
-                 scheduler.step()
-
-                 step += 1
-                 if step % 50 == 0:
-                     grad_norm = compute_grad_norm(model)
-                     avg_loss = running_loss * grad_accum / 50.0
-                     running_loss = 0.0
-                     elapsed = time.time() - t0
-                     lr_now = scheduler.get_last_lr()[0]
-                     print(f"step={step} loss={avg_loss:.4f} grad_norm={grad_norm:.2f} lr={lr_now:.6f} elapsed={elapsed:.1f}s")
-                     t0 = time.time()
-
-                 if save_every and step % save_every == 0:
-                     ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
-                     torch.save({
-                         "model_state_dict": model.state_dict(),
-                         "config": cfg.__dict__,
-                         "step": step,
-                     }, ckpt_path)
-
-             if step >= max_steps:
-                 break
-
-     # final save
-     ckpt_path = os.path.join(out_dir, f"supernova_final.pt")
-     torch.save({
-         "model_state_dict": model.state_dict(),
-         "config": cfg.__dict__,
-         "step": step,
-     }, ckpt_path)
-
-
- if __name__ == "__main__":
-     ap = argparse.ArgumentParser()
-     ap.add_argument("--config", required=True)
-     ap.add_argument("--data-config", required=True)
-     ap.add_argument("--seq-len", type=int, default=1024)
-     ap.add_argument("--batch-size", type=int, default=16)
-     ap.add_argument("--grad-accum", type=int, default=8)
-     ap.add_argument("--lr", type=float, default=3e-4)
-     ap.add_argument("--warmup-steps", type=int, default=2000)
-     ap.add_argument("--max-steps", type=int, default=100000)
-     ap.add_argument("--save-every", type=int, default=10000)
-     ap.add_argument("--out-dir", type=str, default="checkpoints")
-     ap.add_argument("--seed", type=int, default=42)
-     args = ap.parse_args()
-
-     train(
-         config_path=args.config,
-         data_config_path=args.data_config,
-         seq_len=args.seq_len,
-         batch_size=args.batch_size,
-         grad_accum=args.grad_accum,
-         lr=args.lr,
-         warmup_steps=args.warmup_steps,
-         max_steps=args.max_steps,
-         save_every=args.save_every,
-         out_dir=args.out_dir,
-         seed=args.seed,
-     )
+ # train.py (improved)
+ import argparse
+ import json
+ import math
+ import os
+ import random
+ import time
+ from typing import Optional, Dict, Any
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, DistributedSampler
+ from torch.utils.tensorboard import SummaryWriter
+ from transformers import get_cosine_schedule_with_warmup
+
+ from .config import ModelConfig
+ from .model import SupernovaModel
+ from .tokenizer import load_gpt2_tokenizer
+ from .data import load_sources_from_yaml, TokenChunkDataset
+
+ # -----------------------
+ # Utilities
+ # -----------------------
+ def compute_grad_norm(model: nn.Module) -> float:
+     total = 0.0
+     for p in model.parameters():
+         if p.grad is not None:
+             param_norm = p.grad.data.float().norm(2).item()
+             total += param_norm * param_norm
+     return math.sqrt(total)
+
+ def atomic_save(obj: Dict[str, Any], path: str):
+     tmp = path + ".tmp"
+     torch.save(obj, tmp)
+     os.replace(tmp, path)
+
+ class EMA:
+     """Simple exponential moving average of model params (maintains a shadow copy)."""
+     def __init__(self, model: nn.Module, decay: float = 0.9999):
+         self.decay = decay
+         self.shadow = {}
+         for name, p in model.named_parameters():
+             if p.requires_grad:
+                 self.shadow[name] = p.data.clone()
+
+     def update(self, model: nn.Module):
+         for name, p in model.named_parameters():
+             if p.requires_grad:
+                 self.shadow[name].mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
+
+     def store(self, model: nn.Module):
+         self.backup = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}
+
+     def copy_to(self, model: nn.Module):
+         for name, p in model.named_parameters():
+             if p.requires_grad:
+                 p.data.copy_(self.shadow[name])
+
+     def restore(self, model: nn.Module):
+         for name, p in model.named_parameters():
+             if p.requires_grad:
+                 p.data.copy_(self.backup[name])
+         del self.backup
+
+ # -----------------------
+ # Training loop
+ # -----------------------
+ def train(
+     config_path: str,
+     data_config_path: str,
+     seq_len: int = 1024,
+     batch_size: int = 16,
+     grad_accum: int = 8,
+     lr: float = 3e-4,
+     warmup_steps: int = 2000,
+     max_steps: int = 100_000,
+     save_every: int = 10_000,
+     out_dir: str = "checkpoints",
+     seed: int = 42,
+     validate_every: int = 1000,
+     val_steps: int = 100,
+     clip_grad_norm: Optional[float] = 1.0,
+     use_ema: bool = True,
+     ema_decay: float = 0.9999,
+     resume_from: Optional[str] = None,
+     use_tensorboard: bool = True,
+     ddp: bool = False,
+     local_rank: int = 0,
+     num_workers: int = 4,
+     pin_memory: bool = True,
+     compile_model: bool = False,
+ ):
+     # reproducibility
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     random.seed(seed)
+     # performance flags
+     torch.backends.cudnn.benchmark = True
+
+     # device / distributed
+     if ddp:
+         torch.distributed.init_process_group(backend="nccl")
+         device = torch.device(f"cuda:{local_rank}")
+         torch.cuda.set_device(device)
+     else:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # config & tokenizer
+     cfg = ModelConfig.from_json_file(config_path)
+     cfg.assert_exact_params(expected=25_000_000)
+     tok = load_gpt2_tokenizer()
+     assert tok.vocab_size == cfg.vocab_size, (
+         f"Tokenizer vocab size ({tok.vocab_size}) != config ({cfg.vocab_size})"
+     )
+
+     model = SupernovaModel(cfg)
+     # optional: enable gradient checkpointing for memory savings if the model supports it
+     if hasattr(model, "gradient_checkpointing_enable"):
+         try:
+             model.gradient_checkpointing_enable()
+         except Exception:
+             pass
+
+     model.to(device)
+
+     # double-check params
+     total_params = sum(p.numel() for p in model.parameters())
+     assert total_params == 25_000_000, f"Model has {total_params} params, expected 25,000,000"
+
+     # optional compile (PyTorch 2.0+)
+     if compile_model:
+         try:
+             model = torch.compile(model)
+         except Exception as e:
+             print("torch.compile not available/failed:", e)
+
+     # DDP wrap
+     if ddp:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False)
+
+     # dataset and dataloader
+     sources = load_sources_from_yaml(data_config_path)
+     # TODO: improve TokenChunkDataset to perform token packing (pack multiple short examples into one sequence)
+     ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
+
+     sampler = DistributedSampler(ds) if ddp else None
+     dl = DataLoader(
+         ds,
+         batch_size=batch_size,
+         shuffle=(sampler is None),
+         sampler=sampler,
+         num_workers=num_workers,
+         pin_memory=pin_memory,
+         prefetch_factor=2 if num_workers > 0 else None,  # prefetch_factor requires num_workers > 0
+         drop_last=True,
+     )
+
+     # optimizer with simple parameter grouping to avoid weight decay on norms/biases
+     def param_groups(model):
+         decay, no_decay = [], []
+         for n, p in model.named_parameters():
+             if not p.requires_grad:
+                 continue
+             if any(nd in n for nd in ["bias", "ln", "layernorm", "LayerNorm", "norm"]):
+                 no_decay.append(p)
+             else:
+                 decay.append(p)
+         return [
+             {"params": decay, "weight_decay": 0.1},
+             {"params": no_decay, "weight_decay": 0.0},
+         ]
+
+     optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8)
+
+     # scheduler
+     scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
+
+     # AMP scaler
+     scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
+
+     # EMA
+     ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
+
+     # logging + checkpoint dir
+     os.makedirs(out_dir, exist_ok=True)
+     writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None
+
+     # validation dataset (ideally the user provides a separate validation YAML)
+     # TODO: Implement a proper validation dataset pipeline. For now, we use a small subset of training data.
+     val_ds = None
+     val_dl = None
+
+     # resume
+     start_step = 0
+     best_val_loss = float("inf")
+     if resume_from and os.path.exists(resume_from):
+         ckpt = torch.load(resume_from, map_location=device)
+         model_state = ckpt["model_state_dict"]
+         # if ddp, load into the wrapped module
+         target = model.module if ddp else model
+         target.load_state_dict(model_state)
+         # guard: loading an empty dict into the optimizer would raise
+         if "optimizer_state_dict" in ckpt:
+             optimizer.load_state_dict(ckpt["optimizer_state_dict"])
+         scheduler_state = ckpt.get("scheduler_state_dict", None)
+         if scheduler_state:
+             scheduler.load_state_dict(scheduler_state)
+         if "scaler_state_dict" in ckpt and scaler is not None:
+             scaler.load_state_dict(ckpt["scaler_state_dict"])
+         start_step = ckpt.get("step", 0)
+         best_val_loss = ckpt.get("best_val_loss", best_val_loss)
+         print(f"Resumed from {resume_from} at step {start_step}")
+
+     model.train()
+     step = start_step
+     micro = 0
+     running_loss = 0.0
+     t0 = time.time()
+     no_improve_steps = 0
+     early_stop_patience = 10_000  # you can tune this
+
+     # training loop
+     while step < max_steps:
+         if sampler is not None:
+             sampler.set_epoch(step)  # reshuffle each pass over the data for DDP
+
+         for batch in dl:
+             x, y = batch
+             x = x.to(device, non_blocking=True)
+             y = y.to(device, non_blocking=True)
+
+             with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
+                 logits, loss = model(x, y)
+                 loss = loss / grad_accum
+
+             scaler.scale(loss).backward()
+             micro += 1
+             running_loss += loss.item()
+
+             if micro % grad_accum == 0:
+                 # Unscale once so clipping and the norm logging below see true-scale
+                 # gradients; scaler.step() detects that grads are already unscaled.
+                 scaler.unscale_(optimizer)
+                 if clip_grad_norm is not None:
+                     torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+                 # Measure the gradient norm before zero_grad() clears the gradients
+                 # (measuring after the zero would always log 0.0).
+                 grad_norm = compute_grad_norm(model if not ddp else model.module)
+
+                 scaler.step(optimizer)
+                 scaler.update()
+                 optimizer.zero_grad(set_to_none=True)
+                 scheduler.step()
+
+                 if ema:
+                     ema.update(model if not ddp else model.module)
+
+                 step += 1
+
+                 # logging
+                 if step % 50 == 0 and (not ddp or local_rank == 0):
+                     avg_loss = running_loss * grad_accum / 50.0
+                     running_loss = 0.0
+                     elapsed = time.time() - t0
+                     lr_now = scheduler.get_last_lr()[0]
+                     print(f"step={step} loss={avg_loss:.6f} grad_norm={grad_norm:.3f} lr={lr_now:.6f} elapsed={elapsed:.1f}s")
+                     if writer:
+                         writer.add_scalar("train/loss", avg_loss, step)
+                         writer.add_scalar("train/grad_norm", grad_norm, step)
+                         writer.add_scalar("train/lr", lr_now, step)
+                     t0 = time.time()
+
+                 # periodic validation
+                 if validate_every and step % validate_every == 0:
+                     if val_dl is None:
+                         # quick in-memory val split: hold out a small slice of the training sources
+                         # NOTE: for production, create a dedicated validation dataset.
+                         val_ds = TokenChunkDataset(tok, sources[: max(1, len(sources) // 20)], seq_len=seq_len, eos_token_id=tok.eos_token_id)
+                         val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
+
+                     model.eval()
+                     # optionally swap in EMA weights for evaluation
+                     if ema:
+                         ema.store(model if not ddp else model.module)
+                         ema.copy_to(model if not ddp else model.module)
+
+                     val_losses = []
+                     with torch.no_grad():
+                         for i, (vx, vy) in enumerate(val_dl):
+                             if i >= val_steps:
+                                 break
+                             vx = vx.to(device)
+                             vy = vy.to(device)
+                             with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
+                                 _, vloss = model(vx, vy)
+                             val_losses.append(float(vloss.detach().cpu().item()))
+                     mean_val = float(sum(val_losses) / max(1, len(val_losses)))
+                     if writer and (not ddp or local_rank == 0):
+                         writer.add_scalar("val/loss", mean_val, step)
+                     print(f"[eval] step={step} val_loss={mean_val:.6f}")
+
+                     # restore weights
+                     if ema:
+                         ema.restore(model if not ddp else model.module)
+                     model.train()
+
+                     # early stop / best model saving
+                     if mean_val < best_val_loss:
+                         best_val_loss = mean_val
+                         no_improve_steps = 0
+                         best_path = os.path.join(out_dir, f"supernova_best_step{step}.pt")
+                         ckpt = {
+                             "model_state_dict": (model.module.state_dict() if ddp else model.state_dict()),
+                             "optimizer_state_dict": optimizer.state_dict(),
+                             "scheduler_state_dict": scheduler.state_dict(),
+                             "scaler_state_dict": (scaler.state_dict() if scaler else None),
+                             "step": step,
+                             "best_val_loss": best_val_loss,
+                             "config": cfg.__dict__,
+                         }
+                         if not ddp or local_rank == 0:
+                             atomic_save(ckpt, best_path)
+                             print(f"Saved best checkpoint to {best_path}")
+                     else:
+                         no_improve_steps += validate_every
+                         if no_improve_steps >= early_stop_patience:
+                             print("Early stopping triggered.")
+                             step = max_steps
+                             break
+
+                 # periodic checkpointing
+                 if save_every and step % save_every == 0 and (not ddp or local_rank == 0):
+                     ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
+                     ckpt = {
+                         "model_state_dict": (model.module.state_dict() if ddp else model.state_dict()),
+                         "optimizer_state_dict": optimizer.state_dict(),
+                         "scheduler_state_dict": scheduler.state_dict(),
+                         "scaler_state_dict": (scaler.state_dict() if scaler else None),
+                         "step": step,
+                         "best_val_loss": best_val_loss,
+                         "config": cfg.__dict__,
+                     }
+                     atomic_save(ckpt, ckpt_path)
+                     print(f"Saved checkpoint {ckpt_path}")
+
+             if step >= max_steps:
+                 break
+
+         if step >= max_steps:
+             break
+
+     # final save
+     if not ddp or local_rank == 0:
+         ckpt_path = os.path.join(out_dir, f"supernova_final_step{step}.pt")
+         ckpt = {
+             "model_state_dict": (model.module.state_dict() if ddp else model.state_dict()),
+             "optimizer_state_dict": optimizer.state_dict(),
+             "scheduler_state_dict": scheduler.state_dict(),
+             "scaler_state_dict": (scaler.state_dict() if scaler else None),
+             "step": step,
+             "best_val_loss": best_val_loss,
+             "config": cfg.__dict__,
+         }
+         atomic_save(ckpt, ckpt_path)
+         print(f"Saved final checkpoint to {ckpt_path}")
+
+     if writer:
+         writer.close()
+
+
+ if __name__ == "__main__":
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--config", required=True)
+     ap.add_argument("--data-config", required=True)
+     ap.add_argument("--seq-len", type=int, default=1024)
+     ap.add_argument("--batch-size", type=int, default=16)
+     ap.add_argument("--grad-accum", type=int, default=8)
+     ap.add_argument("--lr", type=float, default=3e-4)
+     ap.add_argument("--warmup-steps", type=int, default=2000)
+     ap.add_argument("--max-steps", type=int, default=100000)
+     ap.add_argument("--save-every", type=int, default=10000)
+     ap.add_argument("--out-dir", type=str, default="checkpoints")
+     ap.add_argument("--seed", type=int, default=42)
+     ap.add_argument("--validate-every", type=int, default=1000)
+     ap.add_argument("--val-steps", type=int, default=100)
+     ap.add_argument("--clip-grad-norm", type=float, default=1.0)
+     ap.add_argument("--resume-from", type=str, default=None)
+     # use_ema defaults to True in train(); expose an off switch, matching --no-tensorboard
+     ap.add_argument("--no-ema", dest="use_ema", action="store_false")
+     ap.add_argument("--ema-decay", type=float, default=0.9999)
+     ap.add_argument("--no-tensorboard", dest="use_tensorboard", action="store_false")
+     ap.add_argument("--ddp", action="store_true", help="enable DistributedDataParallel; use with torchrun")
+     # torchrun exports LOCAL_RANK; fall back to it so the flag need not be passed explicitly
+     ap.add_argument("--local-rank", type=int, default=int(os.environ.get("LOCAL_RANK", 0)))
+     ap.add_argument("--num-workers", type=int, default=4)
+     # argparse's type=bool is a trap (the string "False" is truthy); use a store_false flag
+     ap.add_argument("--no-pin-memory", dest="pin_memory", action="store_false")
+     ap.add_argument("--compile", dest="compile_model", action="store_true")
+     args = ap.parse_args()
+
+     train(
+         config_path=args.config,
+         data_config_path=args.data_config,
+         seq_len=args.seq_len,
+         batch_size=args.batch_size,
+         grad_accum=args.grad_accum,
+         lr=args.lr,
+         warmup_steps=args.warmup_steps,
+         max_steps=args.max_steps,
+         save_every=args.save_every,
+         out_dir=args.out_dir,
+         seed=args.seed,
+         validate_every=args.validate_every,
+         val_steps=args.val_steps,
+         clip_grad_norm=args.clip_grad_norm,
+         use_ema=args.use_ema,
+         ema_decay=args.ema_decay,
+         resume_from=args.resume_from,
+         use_tensorboard=args.use_tensorboard,
+         ddp=args.ddp,
+         local_rank=args.local_rank,
+         num_workers=args.num_workers,
+         pin_memory=args.pin_memory,
+         compile_model=args.compile_model,
+     )
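
One loose end the new script flags itself (the TODO above val_ds = None): validation still scores a slice of the training sources, so the reported val_loss is not truly held out. A minimal sketch of the dedicated pipeline that TODO asks for, written as it might sit inside supernova/train.py (which already has the needed imports); the helper name build_val_loader and the --val-data-config flag are hypothetical, while load_sources_from_yaml and TokenChunkDataset are used exactly as train() already uses them:

    # Hypothetical helper: build the validation loader from its own YAML
    # instead of slicing the training sources.
    def build_val_loader(tok, val_config_path, seq_len, batch_size):
        val_sources = load_sources_from_yaml(val_config_path)  # same loader as train()
        val_ds = TokenChunkDataset(
            tok, val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id
        )
        # deterministic order, no dropped tail: every validation chunk is scored once
        return DataLoader(
            val_ds, batch_size=batch_size, shuffle=False,
            num_workers=0, pin_memory=True, drop_last=False,
        )

    # wiring sketch (hypothetical flag):
    #   ap.add_argument("--val-data-config", type=str, default=None)
    #   val_dl = build_val_loader(tok, args.val_data_config, seq_len, batch_size)

For launching, the --ddp help text points at torchrun. Assuming the package layout makes python -m supernova.train the entry point, and with configs/model.json and configs/data.yaml standing in for real paths, a single-GPU run and a 4-GPU DDP run would look roughly like:

    python -m supernova.train --config configs/model.json --data-config configs/data.yaml
    torchrun --nproc_per_node=4 -m supernova.train --config configs/model.json --data-config configs/data.yaml --ddp

With torchrun, each worker process reads its device index from the LOCAL_RANK environment variable, which is why the --local-rank argument falls back to os.environ.get("LOCAL_RANK", 0).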