File size: 269 Bytes
f86c7c7
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from factories import UNet_conditional



# Re-save the checkpoint at SRC_CKPT so that every tensor in the written
# state dict lives on the CPU (output goes to DST_CKPT).
CKPT_DIR = "runs/run_3_jxa/ckpt"
SRC_CKPT = f"{CKPT_DIR}/latest.pt"
DST_CKPT = f"{CKPT_DIR}/latest_cpu.pt"

# Build the model with the same constructor arguments used before; the network
# only serves as a key/shape check for the checkpoint via load_state_dict.
# NOTE(review): if UNet_conditional allocates tensors on `device` at
# construction time, this line still needs MPS available — confirm.
net = UNet_conditional(num_classes=768, device="mps")

# map_location="cpu" deserializes every tensor onto the CPU regardless of the
# device it was saved from, so loading no longer requires that accelerator and
# no transient extra copy is made in accelerator memory. load_state_dict copies
# the values into the model's parameters on whatever device they occupy.
state = torch.load(SRC_CKPT, map_location="cpu")
net.load_state_dict(state)

# Make sure every parameter/buffer is on the CPU before serializing.
net.to("cpu")
torch.save(net.state_dict(), DST_CKPT)