Fix GradScaler resume bug - wrapped scaler.load_state_dict() in try/except at line 512. Allows resuming from checkpoints saved without AMP.
Browse files
n.py
CHANGED
|
@@ -507,7 +507,11 @@ def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
|
|
| 507 |
ar_h.load_state_dict(ck["ar"])
|
| 508 |
sat_h.load_state_dict(ck["sat"])
|
| 509 |
opt.load_state_dict(ck["opt"])
|
| 510 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
|
| 512 |
|
| 513 |
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
|
|
|
|
| 507 |
ar_h.load_state_dict(ck["ar"])
|
| 508 |
sat_h.load_state_dict(ck["sat"])
|
| 509 |
opt.load_state_dict(ck["opt"])
|
| 510 |
+
try:
|
| 511 |
+
if ck.get("scaler"):
|
| 512 |
+
scaler.load_state_dict(ck["scaler"])
|
| 513 |
+
except RuntimeError:
|
| 514 |
+
print("[warn] Could not load scaler state, starting fresh")
|
| 515 |
return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
|
| 516 |
|
| 517 |
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
|