"""Convert a UNet_conditional checkpoint into one whose tensors live on the CPU.

Loads the state dict from ``runs/run_3_jxa/ckpt/latest.pt`` and validates it
against a freshly constructed ``UNet_conditional``. It then re-saves the
weights as CPU tensors to ``runs/run_3_jxa/ckpt/latest_cpu.pt``.
"""
import torch

from factories import UNet_conditional

SRC_CKPT = "runs/run_3_jxa/ckpt/latest.pt"
DST_CKPT = "runs/run_3_jxa/ckpt/latest_cpu.pt"
NUM_CLASSES = 768


def convert_checkpoint_to_cpu(
    src_path=SRC_CKPT,
    dst_path=DST_CKPT,
    num_classes=NUM_CLASSES,
):
    """Re-save the model state dict at ``src_path`` as CPU tensors at ``dst_path``.

    Args:
        src_path: Path of the checkpoint to read (a state dict for
            ``UNet_conditional``, possibly saved from an MPS device).
        dst_path: Path the CPU state dict is written to.
        num_classes: ``num_classes`` argument used to build the model the
            checkpoint is validated against; must match the trained model.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint keys/shapes
            do not match the constructed model.
    """
    # map_location="cpu" remaps tensors saved on MPS (or any other device) to
    # CPU while deserializing, so the conversion no longer needs an MPS backend.
    state_dict = torch.load(src_path, map_location="cpu")

    # NOTE(review): device="cpu" assumes UNet_conditional's `device` argument is
    # not part of its state dict (only tensors are saved) — confirm in factories.
    net = UNet_conditional(num_classes=num_classes, device="cpu")

    # Strict loading checks that the checkpoint matches the architecture before
    # we write anything out.
    net.load_state_dict(state_dict)
    torch.save(net.state_dict(), dst_path)


if __name__ == "__main__":
    convert_checkpoint_to_cpu()