Upload main.ipynb with huggingface_hub
Browse files
- main.ipynb +3 -4
main.ipynb
CHANGED
|
@@ -887,9 +887,8 @@
|
|
| 887 |
"print(f\"Checkpoint was saved at step: {ckpt['step']}\")\n",
|
| 888 |
"\n",
|
| 889 |
"# Load model weights\n",
|
| 890 |
-
"# Load into unwrapped model (
|
| 891 |
-
"
|
| 892 |
-
"load_target.load_state_dict(ckpt['model_state_dict'])\n",
|
| 893 |
"print(\"Model weights loaded\")\n",
|
| 894 |
"\n",
|
| 895 |
"# Load EMA weights\n",
|
|
@@ -900,7 +899,7 @@
|
|
| 900 |
"resume_step = ckpt['step']\n",
|
| 901 |
"if 'optimizer_state_dict' in ckpt:\n",
|
| 902 |
" optimizer = torch.optim.AdamW(\n",
|
| 903 |
-
"
|
| 904 |
" lr=config.learning_rate,\n",
|
| 905 |
" betas=(0.9, 0.98),\n",
|
| 906 |
" weight_decay=config.weight_decay,\n",
|
|
|
|
| 887 |
"print(f\"Checkpoint was saved at step: {ckpt['step']}\")\n",
|
| 888 |
"\n",
|
| 889 |
"# Load model weights\n",
|
| 890 |
+
"# Load into unwrapped model (model_unwrapped set in cell 10)\n",
|
| 891 |
+
"model_unwrapped.load_state_dict(ckpt['model_state_dict'])\n",
|
|
|
|
| 892 |
"print(\"Model weights loaded\")\n",
|
| 893 |
"\n",
|
| 894 |
"# Load EMA weights\n",
|
|
|
|
| 899 |
"resume_step = ckpt['step']\n",
|
| 900 |
"if 'optimizer_state_dict' in ckpt:\n",
|
| 901 |
" optimizer = torch.optim.AdamW(\n",
|
| 902 |
+
" model_unwrapped.parameters(),\n",
|
| 903 |
" lr=config.learning_rate,\n",
|
| 904 |
" betas=(0.9, 0.98),\n",
|
| 905 |
" weight_decay=config.weight_decay,\n",
|