Vedisasi's picture
Upload folder using huggingface_hub
54c5666 verified
import os
import torch
from typing import Optional
def save_checkpoint(checkpoint_dir: str, epoch: int, model, optimizer, scheduler, val_loss: float, config) -> str:
    """Write the training state for ``epoch`` to ``checkpoint_dir``.

    The file is named ``checkpoint_epoch_{epoch}.pt`` and contains a dict with
    the keys ``epoch``, ``model_state_dict``, ``optimizer_state_dict``,
    ``scheduler_state_dict``, ``val_loss`` and ``config``.

    The payload is written to a temporary sibling file first and then moved
    into place with ``os.replace``, so a save that fails part-way does not
    leave a partially written file under the final checkpoint name.

    Args:
        checkpoint_dir: Directory to write into; created if missing. An empty
            string means the current working directory.
        epoch: Epoch number, stored in the payload and used in the file name.
        model: Object providing ``state_dict()`` (or, failing that, a
            ``.module`` attribute that does).
        optimizer: Object providing ``state_dict()``, or ``None``.
        scheduler: Object providing ``state_dict()``, or ``None``.
        val_loss: Validation loss to record alongside the weights.
        config: Stored as-is under ``"config"``; must be serializable by
            ``torch.save``.

    Returns:
        The path of the checkpoint file that was written.
    """
    # os.makedirs("") raises, while os.path.join("", name) is just ``name``.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")

    # Only fall back to the wrapped ``.module`` when the object itself has no
    # ``state_dict``; the saved key names are otherwise left exactly as the
    # model reports them.
    if hasattr(model, "state_dict"):
        model_state = model.state_dict()
    else:
        model_state = model.module.state_dict()

    payload = {
        "epoch": epoch,
        "model_state_dict": model_state,
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
        "val_loss": val_loss,
        "config": config,
    }

    # The PID suffix keeps concurrent writers from sharing one temp file.
    tmp_path = f"{path}.tmp.{os.getpid()}"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only fires
        # when serialization or the move failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return path
def _all_keys_have_module_prefix(keys) -> bool:
    """Return True when ``keys`` is non-empty and every key starts with ``module.``."""
    keys = list(keys)
    return bool(keys) and all(key.startswith("module.") for key in keys)


def load_checkpoint(path: str, model, optimizer=None, scheduler=None) -> Optional[int]:
    """Restore model (and optionally optimizer/scheduler) state from ``path``.

    Model weights are loaded with ``strict=False``, so missing or unexpected
    keys do not raise. Because of that, a ``module.`` prefix mismatch between
    the checkpoint and the target model would otherwise match no keys and
    silently restore nothing. To avoid that, a DataParallel /
    DistributedDataParallel target is unwrapped, and a ``module.`` prefix on
    the checkpoint keys is stripped when the target does not use it.

    Args:
        path: Checkpoint file to read.
        model: Object providing ``load_state_dict``.
        optimizer: If given and the checkpoint holds a non-empty
            ``optimizer_state_dict``, it is restored.
        scheduler: If given and the checkpoint holds a non-empty
            ``scheduler_state_dict``, it is restored.

    Returns:
        The stored epoch as ``int`` (``0`` if the checkpoint has no ``epoch``
        entry), or ``None`` when ``path`` does not exist.
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    if not os.path.exists(path):
        return None

    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from trusted sources. Depending on the installed PyTorch version, the
    # default ``weights_only`` setting may reject a non-trivial ``config``
    # object stored in the payload; confirm against the deployed version.
    payload = torch.load(path, map_location="cpu")

    # Load into the underlying module so the wrapper's own "module." prefix
    # never has to be matched.
    target = model
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        target = model.module

    model_state = payload["model_state_dict"]
    if model_state and hasattr(target, "state_dict"):
        # Strip only when the whole checkpoint is prefixed and the target is
        # not, so a model with a genuine child named ``module`` is untouched.
        saved_prefixed = _all_keys_have_module_prefix(model_state)
        target_prefixed = _all_keys_have_module_prefix(target.state_dict())
        if saved_prefixed and not target_prefixed:
            consume_prefix_in_state_dict_if_present(model_state, "module.")

    target.load_state_dict(model_state, strict=False)

    if optimizer is not None and payload.get("optimizer_state_dict"):
        optimizer.load_state_dict(payload["optimizer_state_dict"])
    if scheduler is not None and payload.get("scheduler_state_dict"):
        scheduler.load_state_dict(payload["scheduler_state_dict"])
    return int(payload.get("epoch", 0))