chipling commited on
Commit
d5df505
·
verified ·
1 Parent(s): 1c89b07

Upload main.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. main.ipynb +3 -4
main.ipynb CHANGED
@@ -887,9 +887,8 @@
887
  "print(f\"Checkpoint was saved at step: {ckpt['step']}\")\n",
888
  "\n",
889
  "# Load model weights\n",
890
- "# Load into unwrapped model (checkpoint saved without DataParallel 'module.' prefix)\n",
891
- "load_target = model.module if hasattr(model, 'module') else model\n",
892
- "load_target.load_state_dict(ckpt['model_state_dict'])\n",
893
  "print(\"Model weights loaded\")\n",
894
  "\n",
895
  "# Load EMA weights\n",
@@ -900,7 +899,7 @@
900
  "resume_step = ckpt['step']\n",
901
  "if 'optimizer_state_dict' in ckpt:\n",
902
  " optimizer = torch.optim.AdamW(\n",
903
- " model.parameters(),\n",
904
  " lr=config.learning_rate,\n",
905
  " betas=(0.9, 0.98),\n",
906
  " weight_decay=config.weight_decay,\n",
 
887
  "print(f\"Checkpoint was saved at step: {ckpt['step']}\")\n",
888
  "\n",
889
  "# Load model weights\n",
890
+ "# Load into unwrapped model (model_unwrapped set in cell 10)\n",
891
+ "model_unwrapped.load_state_dict(ckpt['model_state_dict'])\n",
 
892
  "print(\"Model weights loaded\")\n",
893
  "\n",
894
  "# Load EMA weights\n",
 
899
  "resume_step = ckpt['step']\n",
900
  "if 'optimizer_state_dict' in ckpt:\n",
901
  " optimizer = torch.optim.AdamW(\n",
902
+ " model_unwrapped.parameters(),\n",
903
  " lr=config.learning_rate,\n",
904
  " betas=(0.9, 0.98),\n",
905
  " weight_decay=config.weight_decay,\n",