algorythmtechnologies commited on
Commit
333e53d
·
verified ·
1 Parent(s): 2c2074d

Update supernova/train.py

Browse files
Files changed (1) hide show
  1. supernova/train.py +14 -109
supernova/train.py CHANGED
@@ -10,6 +10,7 @@ import torch.nn as nn
10
  from torch.utils.data import DataLoader, DistributedSampler
11
  from torch.utils.tensorboard import SummaryWriter
12
  from transformers import get_cosine_schedule_with_warmup
 
13
 
14
  from .config import ModelConfig
15
  from .model import SupernovaModel
@@ -45,6 +46,16 @@ def atomic_save(obj: Dict[str, Any], path: str):
45
  torch.save(obj, tmp)
46
  os.replace(tmp, path)
47
 
 
 
 
 
 
 
 
 
 
 
48
  class EMA:
49
  """Simple exponential moving average of model params (maintains shadow copy)."""
50
  def __init__(self, model: nn.Module, decay: float = 0.9999):
@@ -100,6 +111,7 @@ def train(
100
  num_workers: int = 4,
101
  pin_memory: bool = True,
102
  compile_model: bool = False,
 
103
  ):
104
  # reproducibility
105
  torch.manual_seed(seed)
@@ -333,113 +345,6 @@ def train(
333
  best_val_loss = mean_val
334
  no_improve_steps = 0
335
  best_path = os.path.join(out_dir, f"supernova_best_step{step}.pt")
 
336
  ckpt = {
337
- "model_state_dict": (model.module.state_dict() if ddp else model.state_dict()),
338
- "optimizer_state_dict": optimizer.state_dict(),
339
- "scheduler_state_dict": scheduler.state_dict(),
340
- "scaler_state_dict": (scaler.state_dict() if scaler else None),
341
- "step": step,
342
- "best_val_loss": best_val_loss,
343
- "config": cfg.__dict__,
344
- }
345
- if not ddp or local_rank == 0:
346
- atomic_save(ckpt, best_path)
347
- print(f"Saved best checkpoint to {best_path}")
348
- else:
349
- no_improve_steps += validate_every
350
- if no_improve_steps >= early_stop_patience:
351
- print("Early stopping triggered.")
352
- step = max_steps
353
- break
354
-
355
- # periodic checkpointing
356
- if save_every and step % save_every == 0 and (not ddp or local_rank == 0):
357
- ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
358
- ckpt = {
359
- "model_state_dict": (model.module.state_dict() if ddp else model.state_dict()),
360
- "optimizer_state_dict": optimizer.state_dict(),
361
- "scheduler_state_dict": scheduler.state_dict(),
362
- "scaler_state_dict": (scaler.state_dict() if scaler else None),
363
- "step": step,
364
- "best_val_loss": best_val_loss,
365
- "config": cfg.__dict__,
366
- }
367
- atomic_save(ckpt, ckpt_path)
368
- print(f"Saved checkpoint {ckpt_path}")
369
-
370
- if step >= max_steps:
371
- break
372
-
373
- if step >= max_steps:
374
- break
375
-
376
- # final save
377
- if not ddp or local_rank == 0:
378
- ckpt_path = os.path.join(out_dir, f"supernova_final_step{step}.pt")
379
- ckpt = {
380
- "model_state_dict": (model.module.state_dict() if ddp else model.state_dict()),
381
- "optimizer_state_dict": optimizer.state_dict(),
382
- "scheduler_state_dict": scheduler.state_dict(),
383
- "scaler_state_dict": (scaler.state_dict() if scaler else None),
384
- "step": step,
385
- "best_val_loss": best_val_loss,
386
- "config": cfg.__dict__,
387
- }
388
- atomic_save(ckpt, ckpt_path)
389
- print(f"Saved final checkpoint to {ckpt_path}")
390
-
391
- if writer:
392
- writer.close()
393
-
394
- if __name__ == "__main__":
395
- ap = argparse.ArgumentParser()
396
- ap.add_argument("--config", required=True)
397
- ap.add_argument("--data-config", required=True)
398
- ap.add_argument("--seq-len", type=int, default=1024)
399
- ap.add_argument("--batch-size", type=int, default=16)
400
- ap.add_argument("--grad-accum", type=int, default=8)
401
- ap.add_argument("--lr", type=float, default=3e-4)
402
- ap.add_argument("--warmup-steps", type=int, default=2000)
403
- ap.add_argument("--max-steps", type=int, default=100000)
404
- ap.add_argument("--save-every", type=int, default=10000)
405
- ap.add_argument("--out-dir", type=str, default="checkpoints")
406
- ap.add_argument("--seed", type=int, default=42)
407
- ap.add_argument("--validate-every", type=int, default=1000)
408
- ap.add_argument("--val-steps", type=int, default=100)
409
- ap.add_argument("--clip-grad-norm", type=float, default=1.0)
410
- ap.add_argument("--resume-from", type=str, default=None)
411
- ap.add_argument("--use-ema", action="store_true")
412
- ap.add_argument("--ema-decay", type=float, default=0.9999)
413
- ap.add_argument("--no-tensorboard", dest="use_tensorboard", action="store_false")
414
- ap.add_argument("--ddp", action="store_true", help="enable DistributedDataParallel; use with torchrun")
415
- ap.add_argument("--local-rank", type=int, default=0)
416
- ap.add_argument("--num-workers", type=int, default=4)
417
- ap.add_argument("--pin-memory", type=bool, default=True)
418
- ap.add_argument("--compile", dest="compile_model", action="store_true")
419
- args = ap.parse_args()
420
-
421
- train(
422
- config_path=args.config,
423
- data_config_path=args.data_config,
424
- seq_len=args.seq_len,
425
- batch_size=args.batch_size,
426
- grad_accum=args.grad_accum,
427
- lr=args.lr,
428
- warmup_steps=args.warmup_steps,
429
- max_steps=args.max_steps,
430
- save_every=args.save_every,
431
- out_dir=args.out_dir,
432
- seed=args.seed,
433
- validate_every=args.validate_every,
434
- val_steps=args.val_steps,
435
- clip_grad_norm=args.clip_grad_norm,
436
- use_ema=args.use_ema,
437
- ema_decay=args.ema_decay,
438
- resume_from=args.resume_from,
439
- use_tensorboard=args.use_tensorboard,
440
- ddp=args.ddp,
441
- local_rank=args.local_rank,
442
- num_workers=args.num_workers,
443
- pin_memory=args.pin_memory,
444
- compile_model=args.compile_model,
445
- )
 
10
  from torch.utils.data import DataLoader, DistributedSampler
11
  from torch.utils.tensorboard import SummaryWriter
12
  from transformers import get_cosine_schedule_with_warmup
13
+ from safetensors.torch import save_file
14
 
15
  from .config import ModelConfig
16
  from .model import SupernovaModel
 
46
  torch.save(obj, tmp)
47
  os.replace(tmp, path)
48
 
49
+ def save_safetensors(model_state_dict: Dict[str, torch.Tensor], path: str):
50
+ """Save model weights in safetensors format."""
51
+ try:
52
+ tmp = path + ".tmp"
53
+ save_file(model_state_dict, tmp)
54
+ os.replace(tmp, path)
55
+ print(f"Saved safetensors to {path}")
56
+ except Exception as e:
57
+ print(f"Warning: Failed to save safetensors: {e}")
58
+
59
  class EMA:
60
  """Simple exponential moving average of model params (maintains shadow copy)."""
61
  def __init__(self, model: nn.Module, decay: float = 0.9999):
 
111
  num_workers: int = 4,
112
  pin_memory: bool = True,
113
  compile_model: bool = False,
114
+ save_safetensors: bool = True,
115
  ):
116
  # reproducibility
117
  torch.manual_seed(seed)
 
345
  best_val_loss = mean_val
346
  no_improve_steps = 0
347
  best_path = os.path.join(out_dir, f"supernova_best_step{step}.pt")
348
+ model_state = model.module.state_dict() if ddp else model.state_dict()
349
  ckpt = {
350
+ "model_state_dict": model_state