Spaces:
Runtime error
Runtime error
File size: 602 Bytes
eea83e8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | import torch
import shutil
import os
import torchvision.utils as tvu
def save_image(img, file_directory, **kwargs):
    """Write image data to disk with ``torchvision.utils.save_image``.

    The parent directory of the destination file is created first if it
    does not already exist.

    Args:
        img: Image data, passed unchanged to ``tvu.save_image``.
        file_directory: Destination *file* path (despite the name), including
            the file name and extension. A bare file name (no directory part)
            is written relative to the current working directory.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``tvu.save_image`` (e.g. grid options such as ``nrow``).
    """
    parent = os.path.dirname(file_directory)
    # dirname() is '' for a bare file name and os.makedirs('') raises, so only
    # create a directory when there is one. exist_ok=True also removes the
    # exists()/makedirs() race when several writers target the same folder.
    if parent:
        os.makedirs(parent, exist_ok=True)
    tvu.save_image(img, file_directory, **kwargs)
def save_checkpoint(state, filename):
    """Serialize ``state`` with ``torch.save`` to ``filename + '.pth.tar'``.

    Args:
        state: Any object accepted by ``torch.save``.
        filename: Destination path *without* the ``.pth.tar`` suffix, which
            is always appended. May be a ``str`` or an ``os.PathLike``. Its
            parent directory is created if missing. A bare name is written
            relative to the current working directory.

    NOTE: the file is a pickle; only load it from trusted locations.
    """
    # fspath() lets pathlib.Path callers work with the string suffix below.
    filename = os.fspath(filename)
    parent = os.path.dirname(filename)
    # Skip makedirs for a bare file name (dirname == '' would make it raise);
    # exist_ok=True avoids the check-then-create race.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, filename + '.pth.tar')
def load_checkpoint(path, device):
    """Load and return an object previously written with ``torch.save``.

    Args:
        path: File to read.
        device: Forwarded to ``torch.load`` as ``map_location``. ``None``
            leaves torch's default placement in effect.

    NOTE: ``torch.load`` unpickles data; do not use on untrusted files.
    """
    # None is already torch.load's default for map_location, so a single
    # call covers both the "no device" and the "explicit device" cases.
    return torch.load(path, map_location=device)
|