import os
import torch
from typing import Optional

def save_checkpoint(checkpoint_dir: str, epoch: int, model, optimizer, scheduler, val_loss: float, config) -> str:
    """Serialize the full training state for this epoch and return the file path."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
    payload = {
        "epoch": epoch,
        # Unwrap DataParallel/DistributedDataParallel so the state dict is
        # saved without "module."-prefixed keys.
        "model_state_dict": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
        "val_loss": val_loss,
        "config": config,
    }
    torch.save(payload, path)
    return path

def load_checkpoint(path: str, model, optimizer=None, scheduler=None) -> Optional[int]:
    """Restore training state from `path`; return the saved epoch, or None if the file is absent."""
    if not os.path.exists(path):
        return None
    # The payload includes a non-tensor config object, so newer PyTorch
    # releases (which default torch.load to weights_only=True) need an
    # explicit weights_only=False; only load checkpoints you trust.
    payload = torch.load(path, map_location="cpu", weights_only=False)
    # strict=False tolerates missing or unexpected keys (e.g. minor
    # architecture changes) at the cost of silently skipping them.
    model.load_state_dict(payload["model_state_dict"], strict=False)
    if optimizer is not None and payload.get("optimizer_state_dict"):
        optimizer.load_state_dict(payload["optimizer_state_dict"])
    if scheduler is not None and payload.get("scheduler_state_dict"):
        scheduler.load_state_dict(payload["scheduler_state_dict"])
    return int(payload.get("epoch", 0))
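
# Usage sketch. The model, optimizer, paths, and config below are
# illustrative stand-ins, not part of the original module: resume from a
# checkpoint if one exists, otherwise start from epoch 0.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 2)  # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    last = load_checkpoint("checkpoints/checkpoint_epoch_3.pt", model, optimizer, scheduler)
    start_epoch = (last + 1) if last is not None else 0

    for epoch in range(start_epoch, 5):
        # ... training and validation would go here ...
        val_loss = 0.0  # placeholder
        save_checkpoint("checkpoints", epoch, model, optimizer, scheduler, val_loss, config={"lr": 0.1})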