# text-to-image-model / model_to_cpu.py
# Author: JBlitzar
# Commit: f86c7c7
import torch
from factories import UNet_conditional
# Re-save the run_3_jxa checkpoint with every tensor on the CPU, so the result
# can be loaded on machines that have no MPS (Apple GPU) device.
#
# The checkpoint is loaded into a freshly built UNet_conditional instead of
# re-saving the raw dict. That way load_state_dict's strict key check fails
# loudly if the checkpoint does not match the current model definition.
SRC_CKPT = "runs/run_3_jxa/ckpt/latest.pt"
DST_CKPT = "runs/run_3_jxa/ckpt/latest_cpu.pt"

# Build the model directly on the CPU. The original built it for "mps" and
# moved it there, which made this script fail on any machine without MPS.
# NOTE(review): `device` is forwarded to the model constructor; presumably it
# only controls where internal tensors are created — confirm in factories.py.
net = UNet_conditional(num_classes=768, device="cpu")
net.to("cpu")

# map_location="cpu" remaps tensors saved from MPS (or any other device)
# straight onto the CPU, avoiding a pointless round trip through the GPU.
state_dict = torch.load(SRC_CKPT, map_location="cpu")
net.load_state_dict(state_dict)

torch.save(net.state_dict(), DST_CKPT)