File size: 1,181 Bytes
24870a9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | import os
import cv2
import numpy as np
from PIL import Image
import torch
def gen_noise(shape):
    """Return a random float32 tensor of the given shape holding 0.0/1.0 values.

    A zeroed uint8 buffer is filled by ``cv2.randn`` (mean 0, stddev 255),
    divided by 255, cast back to uint8 and finally converted to a
    ``torch.float32`` tensor.

    Args:
        shape: Shape of the array handed to ``np.zeros`` (and therefore of
            the returned tensor).

    Returns:
        A ``torch.float32`` tensor with the requested shape.
    """
    buffer = np.zeros(shape, dtype=np.uint8)

    # Fill the uint8 buffer with normally distributed samples.
    sampled = cv2.randn(buffer, 0, 255)

    # The uint8 samples lie in 0..255, so after dividing by 255 the cast back
    # to uint8 truncates every element to 0 or 1.
    # NOTE(review): this truncation may be unintentional, but it is kept
    # as-is because downstream code may depend on it -- confirm.
    quantized = (sampled / 255).astype(np.uint8)

    return torch.tensor(quantized, dtype=torch.float32)
def save_images(img_tensors, img_names, save_dir):
    """Write each tensor in ``img_tensors`` to ``save_dir`` as a JPEG file.

    Every tensor is rescaled with ``(x + 1) * 0.5 * 255``, clamped to
    ``[0, 255]`` and converted to a uint8 array before being saved.

    Args:
        img_tensors: Iterable of image tensors. Presumably channel-first
            ``(C, H, W)`` with values nominally in ``[-1, 1]`` -- confirm
            against callers.
        img_names: Iterable of file names, paired one-to-one with
            ``img_tensors`` (extra items on either side are ignored by
            ``zip``).
        save_dir: Directory the files are written into.

    Returns:
        None.
    """
    for img_tensor, img_name in zip(img_tensors, img_names):
        # The arithmetic allocates a new tensor, so the input is not modified.
        scaled = (img_tensor + 1) * 0.5 * 255

        # Detach unconditionally: it is a no-op for tensors without grad and
        # removes the need for a bare ``except`` around ``.numpy()``.
        array = scaled.detach().cpu().clamp(0, 255).numpy().astype('uint8')

        if array.shape[0] == 1:
            # Single leading channel: drop it so PIL receives a 2-D array.
            array = array.squeeze(0)
        elif array.shape[0] == 3:
            # Three leading channels: move them last (C, H, W) -> (H, W, C).
            array = array.transpose(1, 2, 0)

        im = Image.fromarray(array)
        # The format is forced to JPEG regardless of the extension in img_name.
        im.save(os.path.join(save_dir, img_name), format='JPEG')
def load_checkpoint(model, checkpoint_path):
    """Load a saved state dict from ``checkpoint_path`` into ``model``.

    Args:
        model: Object exposing ``load_state_dict`` (e.g. a ``torch.nn.Module``).
        checkpoint_path: Filesystem path of the checkpoint to read.

    Raises:
        ValueError: If ``checkpoint_path`` does not exist.
    """
    checkpoint_found = os.path.exists(checkpoint_path)
    if not checkpoint_found:
        message = "'{}' is not a valid checkpoint path".format(checkpoint_path)
        raise ValueError(message)

    # map_location='cpu' asks torch.load to place the stored tensors on the
    # CPU rather than on the device they were saved from.
    # NOTE(review): torch.load may unpickle arbitrary objects; only load
    # checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
|