# image2painting / checkpoint.py
# Lasercatz
# Upload 9 files
# 97bca33 verified
import os
import re
from typing import Optional

import torch
class CheckpointManager:
    """Save and restore training state (model, AMP scaler, optimizer, scheduler).

    Each checkpoint is a single ``torch.save`` file named
    ``epoch_{epoch}_sample_{last_sample_idx}.pth`` inside ``self.ckpt_dir``.
    """

    # Matches names produced by ``_format_filename``; used to order checkpoints
    # numerically (a plain string sort puts "epoch_10" before "epoch_2").
    _FILENAME_RE = re.compile(r"^epoch_(\d+)_sample_(\d+)\.pth$")

    def __init__(self, relative_dir: str = "checkpoints"):
        """
        Args:
            relative_dir (str): Directory where checkpoints are stored. A
                relative path is resolved against the directory containing
                this file; an absolute path is used as-is (``os.path.join``
                semantics). The directory is created if it does not exist.
        """
        base_dir = os.path.dirname(os.path.abspath(__file__))
        self.ckpt_dir = os.path.join(base_dir, relative_dir)
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def _format_filename(self, epoch: int, last_sample_idx: int) -> str:
        """Return the full path of the checkpoint for (epoch, last_sample_idx)."""
        return os.path.join(self.ckpt_dir, f"epoch_{epoch}_sample_{last_sample_idx}.pth")

    def _sort_key(self, name: str) -> tuple:
        """Ordering key: recognised names by numeric (epoch, sample); others rank lowest."""
        match = self._FILENAME_RE.match(name)
        if match is None:
            return (0, 0, 0, name)
        return (1, int(match.group(1)), int(match.group(2)), name)

    def _latest_filename(self) -> Optional[str]:
        """Return the path of the newest ``.pth`` file in ckpt_dir, or None if there is none."""
        files = [f for f in os.listdir(self.ckpt_dir) if f.endswith(".pth")]
        if not files:
            return None
        return os.path.join(self.ckpt_dir, max(files, key=self._sort_key))

    def save(self, model, scaler, optimizer, scheduler, epoch, last_sample_idx):
        """Write a checkpoint for (epoch, last_sample_idx).

        Args:
            model: Object exposing ``state_dict()``.
            scaler, optimizer, scheduler: Objects exposing ``state_dict()``, or
                None; a None component is stored as None and skipped on load.
            epoch: Stored under "epoch" and used in the filename.
            last_sample_idx: Stored under "last_sample_idx" and used in the filename.
        """
        state = {
            "epoch": epoch,
            "last_sample_idx": last_sample_idx,
            "model": model.state_dict(),
            "scaler": scaler.state_dict() if scaler is not None else None,
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
            "scheduler": scheduler.state_dict() if scheduler is not None else None,
        }
        filename = self._format_filename(epoch, last_sample_idx)
        # Write to a temporary name and atomically move it into place, so an
        # interrupted save never leaves a truncated ".pth" that load() would
        # pick as the latest checkpoint. The ".tmp" suffix keeps it out of
        # load()'s ".pth" scan.
        tmp_filename = filename + ".tmp"
        try:
            torch.save(state, tmp_filename)
            os.replace(tmp_filename, filename)
        finally:
            # Only still present if torch.save/os.replace failed.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
        print(f"[checkpoint.py] Saved checkpoint to {filename}")

    def load(self, model, scaler=None, optimizer=None, scheduler=None,
             filename: Optional[str] = None, map_location=None):
        """Restore state from a checkpoint into the given objects.

        Args:
            model: Receives ``load_state_dict(checkpoint["model"])``.
            scaler, optimizer, scheduler: Optional; each is restored only if it
                is not None and the checkpoint holds a non-None entry for it.
            filename: Path to load. If None, the checkpoint in ckpt_dir with
                the highest numeric (epoch, sample) in its name is used.
            map_location: Forwarded to ``torch.load`` (e.g. "cpu" to load a
                GPU-saved checkpoint on a CPU-only machine).

        Returns:
            (epoch, last_sample_idx) from the checkpoint, or (0, 0) if
            filename is None and no ``.pth`` files exist in ckpt_dir.
        """
        if filename is None:
            filename = self._latest_filename()
            if filename is None:
                print(f"No checkpoints found in {self.ckpt_dir}")
                return 0, 0
        checkpoint = torch.load(filename, map_location=map_location)
        model.load_state_dict(checkpoint["model"])
        for component, key in ((scaler, "scaler"), (optimizer, "optimizer"), (scheduler, "scheduler")):
            if component is not None and checkpoint.get(key) is not None:
                component.load_state_dict(checkpoint[key])
        print(f"[checkpoint.py] Loaded checkpoint from {filename}")
        return checkpoint["epoch"], checkpoint["last_sample_idx"]