"""Checkpoint save/load helpers."""
from pathlib import Path
import torch
import torch.nn as nn
from torch.optim import Optimizer
def save_checkpoint(
generator: nn.Module,
discriminator: nn.Module,
opt_g: Optimizer,
opt_d: Optimizer,
epoch: int,
path: str | Path,
) -> None:
path = Path(path)
torch.save(
{
"epoch": epoch,
"generator_state": generator.state_dict(),
"discriminator_state": discriminator.state_dict(),
"opt_g_state": opt_g.state_dict(),
"opt_d_state": opt_d.state_dict(),
},
path,
)
print(f"[Checkpoint] Saved → {path}")
def load_checkpoint(
path: str | Path,
generator: nn.Module,
discriminator: nn.Module,
opt_g: Optimizer,
opt_d: Optimizer,
device: torch.device | str = "cpu",
) -> int:
ckpt = torch.load(path, map_location=device)
generator.load_state_dict(ckpt["generator_state"])
discriminator.load_state_dict(ckpt["discriminator_state"])
opt_g.load_state_dict(ckpt["opt_g_state"])
opt_d.load_state_dict(ckpt["opt_d_state"])
epoch = ckpt.get("epoch", 0)
print(f"[Checkpoint] Loaded ← {path} (epoch {epoch})")
return epoch
def load_generator_only(
path: str | Path,
generator: nn.Module,
device: torch.device | str = "cpu",
) -> nn.Module:
"""Load only the generator weights – used at inference time."""
ckpt = torch.load(path, map_location=device)
generator.load_state_dict(ckpt["generator_state"])
generator.eval()
print(f"[Checkpoint] Generator loaded ← {path}")
return generator
|