"""Re-save a ``UNet_conditional`` checkpoint so it loads on CPU-only machines.

By default this reads ``runs/run_3_jxa/ckpt/latest.pt`` and loads it into a
freshly built ``UNet_conditional(num_classes=768)``. That load is strict, so it
checks that the checkpoint matches the architecture. The model's state dict is
then written, with every tensor on the CPU, to
``runs/run_3_jxa/ckpt/latest_cpu.pt``.
"""
import torch
from factories import UNet_conditional

SRC_CKPT = "runs/run_3_jxa/ckpt/latest.pt"
DST_CKPT = "runs/run_3_jxa/ckpt/latest_cpu.pt"


def convert_to_cpu(src=SRC_CKPT, dst=DST_CKPT, num_classes=768):
    """Load the state dict at *src*, validate it, and save a CPU copy to *dst*.

    Args:
        src: Path of the checkpoint to read (a ``state_dict`` for the model).
        dst: Path the CPU-resident ``state_dict`` is written to.
        num_classes: Passed to ``UNet_conditional``; must match the checkpoint.

    Returns:
        The CPU model with the loaded weights.

    Raises:
        RuntimeError: from ``load_state_dict`` if keys or shapes do not match.
    """
    # map_location="cpu" remaps tensors that were saved from an MPS device,
    # so the conversion no longer requires MPS hardware to be present.
    state_dict = torch.load(src, map_location="cpu")

    # NOTE(review): the original built the model with device="mps". This
    # assumes `device` only selects where runtime tensors are created and does
    # not change the parameter set. Confirm against UNet_conditional.
    net = UNet_conditional(num_classes=num_classes, device="cpu")
    net.to("cpu")  # explicit, in case the constructor moved anything elsewhere
    net.load_state_dict(state_dict)  # strict: verifies the architecture matches

    torch.save(net.state_dict(), dst)
    return net


if __name__ == "__main__":
    convert_to_cpu()