# LYL1015's picture
# test
# eea83e8
import torch
import shutil
import os
import torchvision.utils as tvu
def save_image(img, file_directory):
    """Write ``img`` to ``file_directory`` as an image, creating parent dirs.

    Args:
        img: Image data forwarded unchanged to ``torchvision.utils.save_image``
            (presumably a tensor or list of tensors -- confirm with callers).
        file_directory: Destination *file* path (despite the name, this is the
            path of the image file itself, not a directory).
    """
    parent = os.path.dirname(file_directory)
    # A bare filename has dirname "" and os.makedirs("") raises, so only create
    # a directory when there is one. exist_ok=True replaces the racy
    # "exists? then makedirs" check (another worker may create it in between).
    if parent:
        os.makedirs(parent, exist_ok=True)
    tvu.save_image(img, file_directory)
def save_checkpoint(state, filename):
    """Serialize ``state`` with ``torch.save`` to ``filename + '.pth.tar'``.

    Parent directories of ``filename`` are created if they do not exist.

    Args:
        state: Any object ``torch.save`` can serialize (e.g. a dict holding a
            model ``state_dict``).
        filename: Output path *without* extension; ``'.pth.tar'`` is appended.
    """
    parent = os.path.dirname(filename)
    # A bare filename has dirname "" and os.makedirs("") raises, so only create
    # a directory when there is one. exist_ok=True avoids the check-then-create
    # race when several processes save into the same new directory.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, filename + '.pth.tar')
def load_checkpoint(path, device=None):
    """Load an object previously written with ``torch.save``.

    Args:
        path: Full on-disk path to load. ``save_checkpoint`` appends
            ``'.pth.tar'`` to its filename, so include that suffix here.
        device: Optional ``map_location`` for ``torch.load`` (e.g. ``'cpu'``
            or a ``torch.device``). ``None`` (the default) uses torch's
            default placement.

    Returns:
        The deserialized object.
    """
    # torch.load's own map_location default is None, so forwarding `device`
    # unconditionally is equivalent to branching on `device is None`.
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    return torch.load(path, map_location=device)