# (removed web-scrape artifacts: file-size line, commit hash, line-number gutter)
import torch
from factories import UNet_conditional
# Convert an MPS-trained checkpoint into a CPU-loadable state dict so it can be
# used on machines without Apple-silicon / MPS support.
net = UNet_conditional(num_classes=768, device="mps")
# map_location="cpu" remaps every tensor's serialized device tag (mps/cuda) to
# CPU at load time, so this works even where the original device is unavailable;
# load_state_dict then copies the CPU tensors onto the model's own parameters.
# weights_only=True restricts torch.load's unpickler to tensors/containers —
# safe here because the file holds a plain state dict (it is passed straight
# to load_state_dict).
state = torch.load(
    "runs/run_3_jxa/ckpt/latest.pt",
    map_location="cpu",
    weights_only=True,
)
net.load_state_dict(state)
# Ensure every parameter/buffer lives on CPU before re-serializing.
# NOTE(review): the constructor still receives device="mps" — presumably it
# only stores the string for later use; confirm it does not allocate on MPS
# at construction time if this script must run on non-Mac hardware.
net.to("cpu")
torch.save(net.state_dict(), "runs/run_3_jxa/ckpt/latest_cpu.pt")