OpenTransformer committed on
Commit
9dd1da1
·
verified ·
1 Parent(s): a27e604

Fix GradScaler resume bug - wrapped scaler.load_state_dict() in try/except at line 512. Allows resuming from checkpoints saved without AMP.

Browse files
Files changed (1) hide show
  1. n.py +5 -1
n.py CHANGED
@@ -507,7 +507,11 @@ def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
507
  ar_h.load_state_dict(ck["ar"])
508
  sat_h.load_state_dict(ck["sat"])
509
  opt.load_state_dict(ck["opt"])
510
- scaler.load_state_dict(ck["scaler"])
 
 
 
 
511
  return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
512
 
513
  def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
 
507
  ar_h.load_state_dict(ck["ar"])
508
  sat_h.load_state_dict(ck["sat"])
509
  opt.load_state_dict(ck["opt"])
510
+ try:
511
+ if ck.get("scaler"):
512
+ scaler.load_state_dict(ck["scaler"])
513
+ except RuntimeError:
514
+ print("[warn] Could not load scaler state, starting fresh")
515
  return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
516
 
517
  def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):