|
|
"""
|
|
|
Created on 2020/9/8
|
|
|
|
|
|
@author: Boyun Li
|
|
|
"""
|
|
|
import os
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import random
|
|
|
import torch.nn as nn
|
|
|
from torch.nn import init
|
|
|
from PIL import Image
|
|
|
|
|
|
class EdgeComputation(nn.Module):
    """Approximate per-pixel edge strength from absolute neighbour differences.

    For every pixel the absolute differences to its horizontal and vertical
    neighbours are accumulated (border pixels receive fewer contributions),
    summed over the channel axis, divided by 3 and then by 4.

    Args:
        test: If True, ``forward`` takes a batched ``(N, C, H, W)`` tensor and
            returns ``(N, 1, H, W)``. If False, it takes an unbatched
            ``(C, H, W)`` tensor and returns ``(1, H, W)``.
    """

    def __init__(self, test=False):
        super(EdgeComputation, self).__init__()
        self.test = test

    def forward(self, x):
        """Return the edge map of ``x`` (shapes as described on the class)."""
        if self.test:
            return self._edge_map(x)
        # Unbatched input: add a batch axis, reuse the batched computation,
        # then drop that axis again -> (1, H, W).
        return self._edge_map(x.unsqueeze(0)).squeeze(0)

    @staticmethod
    def _edge_map(x):
        """Edge map of a batched ``(N, C, H, W)`` tensor, returned as ``(N, 1, H, W)``."""
        x_diffx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        x_diffy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])

        # Allocate the accumulator on x's device (the old torch.Tensor(...)
        # allocation was always on the CPU, breaking GPU inputs). Floating
        # inputs keep their dtype; other dtypes fall back to the default
        # float dtype, as before.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        y = torch.zeros(x.shape, dtype=dtype, device=x.device)

        # Each difference is credited to both pixels it connects.
        y[:, :, :, 1:] += x_diffx
        y[:, :, :, :-1] += x_diffx
        y[:, :, 1:, :] += x_diffy
        y[:, :, :-1, :] += x_diffy

        # Sum over channels; the fixed /3 assumes 3-channel input.
        y = torch.sum(y, 1, keepdim=True) / 3
        y /= 4
        return y
|
|
|
|
|
|
|
|
|
|
|
|
def crop_patch(im, pch_size):
    """Return a random square ``pch_size`` x ``pch_size`` crop of ``im``.

    Cropping happens over the first two axes; any trailing axes are kept
    whole. The top offset is drawn before the left offset, each with
    ``random.randint``, so results are reproducible under ``random.seed``.

    Raises:
        ValueError: (from ``random.randint``) if ``pch_size`` exceeds either
            of the first two dimensions of ``im``.
    """
    height, width = im.shape[0], im.shape[1]
    top = random.randint(0, height - pch_size)
    left = random.randint(0, width - pch_size)
    rows = slice(top, top + pch_size)
    cols = slice(left, left + pch_size)
    return im[rows, cols]
|
|
|
|
|
|
|
|
|
|
|
|
def crop_img(image, base=64):
    """Center-crop ``image`` so its height and width are multiples of ``base``.

    The surplus rows (``h % base``) and columns (``w % base``) are removed
    from both sides; when the surplus is odd, the extra pixel comes off the
    bottom/right. Only the first two axes are sliced, so 2-D ``(H, W)`` and
    channel-last ``(H, W, C)`` arrays are both accepted. The result is a
    view into ``image``.
    """
    h, w = image.shape[0], image.shape[1]
    crop_h = h % base
    crop_w = w % base
    top = crop_h // 2
    left = crop_w // 2
    # Slicing only the first two axes (instead of forcing a third ``:``)
    # lets grayscale 2-D arrays through; 3-D results are unchanged.
    return image[top:h - crop_h + top, left:w - crop_w + left]
|
|
|
|
|
|
|
|
|
|
|
|
def slice_image2patches(image, patch_size=64, overlap=0):
    """Cut an ``(H, W, C)`` image into a stack of square patches.

    H and W must both be multiples of ``patch_size``. The image is
    edge-padded by ``overlap`` pixels on every side, and each patch spans
    ``patch_size + overlap`` pixels: its ``patch_size`` tile plus
    ``overlap`` pixels of context above and to the left. The bottom/right
    padding is never read. Patches are emitted row-major (left to right
    within a tile row, then the next row down).

    Returns:
        Array of shape ``(num_patches, patch_size + overlap,
        patch_size + overlap, C)``.
    """
    assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0
    tile_rows = image.shape[0] // patch_size
    tile_cols = image.shape[1] // patch_size
    pad_spec = ((overlap, overlap), (overlap, overlap), (0, 0))
    padded = np.pad(image, pad_spec, mode='edge')
    span = patch_size + overlap
    tiles = [
        padded[r * patch_size:r * patch_size + span, c * patch_size:c * patch_size + span, :]
        for r in range(tile_rows)
        for c in range(tile_cols)
    ]
    return np.stack(tiles, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
def splice_patches2image(patches, image_size, overlap=0):
    """Reassemble square patches into a single image.

    Each patch is treated as carrying ``overlap`` pixels of leading context
    on its top and left edges. That context is dropped, and the remaining
    ``P - overlap`` square tile is written back in row-major order.

    Args:
        patches: 4-D array ``(N, P, P, C)`` of square patches.
        image_size: Shape of the output buffer, e.g. ``(H, W, C)``.
        overlap: Width of the leading context to discard from each patch.

    Returns:
        float64 array of shape ``image_size``. Any area not covered by a
        whole tile stays zero.
    """
    assert len(image_size) > 1
    assert patches.shape[-3] == patches.shape[-2]
    height, width = image_size[0], image_size[1]
    tile = patches.shape[-2] - overlap
    canvas = np.zeros(image_size)
    tiles_per_row = width // tile
    for k in range(height // tile * tiles_per_row):
        r, c = divmod(k, tiles_per_row)
        inner = patches[k, overlap:tile + overlap, overlap:tile + overlap, :]
        canvas[r * tile:(r + 1) * tile, c * tile:(c + 1) * tile, :] = inner
    return canvas
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def data_augmentation(image, mode):
    """Apply one of 8 rotation/flip transforms to ``image``.

    ``mode`` encodes ``2 * k + f``. The image is first rotated ``k`` times
    with ``np.rot90`` (plane of axes 0 and 1). When ``f`` is 1, the result
    is then flipped with ``np.flipud``:

        0: identity            1: flipud
        2: rot90               3: rot90 + flipud
        4: rot90 x2            5: rot90 x2 + flipud
        6: rot90 x3            7: rot90 x3 + flipud

    Args:
        image: A numpy array, or anything ``np.asarray`` accepts (e.g. a CPU
            torch tensor).
        mode: Integer in ``[0, 7]``.

    Returns:
        A numpy array that may be a view of ``image``; call ``.copy()`` if
        an independent buffer is needed.

    Raises:
        ValueError: if ``mode`` is not in ``[0, 7]``.
    """
    if mode not in range(8):
        raise ValueError('Invalid choice of image transformation')
    num_rot, flip = divmod(int(mode), 2)
    # Mode 0 used to call ``image.numpy()``, which only exists on torch
    # tensors; ``np.asarray`` accepts numpy arrays and CPU tensors alike,
    # matching what np.rot90/np.flipud accept in the other modes.
    out = np.rot90(image, k=num_rot) if num_rot else np.asarray(image)
    if flip:
        out = np.flipud(out)
    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def random_augmentation(*args):
    """Apply one randomly chosen transform identically to every argument.

    A single mode is drawn with ``random.randint(1, 7)``, so the identity
    (mode 0) is never selected. That mode is passed to ``data_augmentation``
    for each input, and each result is ``.copy()``-ed so it owns its memory.

    Returns:
        list holding one transformed copy per positional argument, in order.
    """
    mode = random.randint(1, 7)
    return [data_augmentation(item, mode).copy() for item in args]
|
|
|
|
|
|
|
|
|
def weights_init_normal_(m):
    """Initialise one module's parameters; intended for ``net.apply``.

    * Class name containing ``Conv`` or ``Linear``: weight ~ U(0.0, 0.02).
    * Class name containing ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0.

    Other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.uniform_(m.weight, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        # Was ``uniform(1.0, 0.02)``: a range whose lower bound exceeds its
        # upper bound, which torch rejects with a RuntimeError.
        # N(1.0, 0.02) is the DCGAN-style BatchNorm init — confirm intent.
        # weight/bias are None for affine=False BatchNorm, so guard them.
        if m.weight is not None:
            init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
|
|
|
|
|
|
|
|
|
def weights_init_normal(m):
    """Per-module initialiser used by ``init_weights(net, 'normal')``.

    * Class name containing ``Conv2d``: delegated to ``weights_init_normal_``
      through ``m.apply``.
    * Class name containing ``Linear``: weight ~ U(0.0, 0.02).
    * Class name containing ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0.
    """
    classname = m.__class__.__name__
    if 'Conv2d' in classname:
        m.apply(weights_init_normal_)
    elif 'Linear' in classname:
        init.uniform_(m.weight, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        # Was ``uniform(1.0, 0.02)``: an invalid range (lower > upper) that
        # torch rejects. weight/bias are None when affine=False.
        if m.weight is not None:
            init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
|
|
|
|
|
|
|
|
|
def weights_init_xavier(m):
    """Per-module Xavier initialiser; intended for ``net.apply``.

    * Class name containing ``Conv`` or ``Linear``: Xavier-normal weight
      (gain 1).
    * Class name containing ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.xavier_normal_(m.weight, gain=1)
    elif 'BatchNorm2d' in classname:
        # Was ``uniform(1.0, 0.02)``: an invalid range (lower > upper) that
        # torch rejects. weight/bias are None when affine=False.
        if m.weight is not None:
            init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
|
|
|
|
|
|
|
|
|
def weights_init_kaiming(m):
    """Per-module Kaiming initialiser; intended for ``net.apply``.

    * Class name containing ``Conv`` or ``Linear``: Kaiming-normal weight
      (``a=0``, ``mode='fan_in'``).
    * Class name containing ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.kaiming_normal_(m.weight, a=0, mode='fan_in')
    elif 'BatchNorm2d' in classname:
        # Was ``uniform(1.0, 0.02)``: an invalid range (lower > upper) that
        # torch rejects. weight/bias are None when affine=False.
        if m.weight is not None:
            init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
|
|
|
|
|
|
|
|
|
def weights_init_orthogonal(m):
    """Per-module orthogonal initialiser; intended for ``net.apply``.

    * Class name containing ``Conv`` or ``Linear``: orthogonal weight
      (gain 1).
    * Class name containing ``BatchNorm2d``: weight ~ N(1.0, 0.02), bias = 0.

    The former debug ``print`` of every visited class name was removed.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.orthogonal_(m.weight, gain=1)
    elif 'BatchNorm2d' in classname:
        # Was ``uniform(1.0, 0.02)``: an invalid range (lower > upper) that
        # torch rejects. weight/bias are None when affine=False.
        if m.weight is not None:
            init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
|
|
|
|
|
|
|
|
|
def init_weights(net, init_type='normal'):
    """Initialise every submodule of ``net`` with the selected scheme.

    The chosen method name is printed first. The matching per-module
    initialiser is then applied through ``net.apply``.

    Args:
        net: The module to initialise.
        init_type: ``'normal'``, ``'xavier'``, ``'kaiming'`` or
            ``'orthogonal'``.

    Raises:
        NotImplementedError: for any other ``init_type``.
    """
    print('initialization method [%s]' % init_type)
    initialisers = (
        ('normal', weights_init_normal),
        ('xavier', weights_init_xavier),
        ('kaiming', weights_init_kaiming),
        ('orthogonal', weights_init_orthogonal),
    )
    for scheme, initialiser in initialisers:
        if init_type == scheme:
            net.apply(initialiser)
            return
    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
|
|
|
|
|
|
|
|
def np_to_torch(img_np):
    """Convert a numpy image to a torch tensor with a new leading batch axis.

    From C x W x H [0..1] to 1 x C x W x H [0..1]. ``torch.from_numpy``
    shares memory with ``img_np``, and the batch axis is added as a view,
    so no data is copied.

    :param img_np: numpy array to wrap.
    :return: torch.Tensor with one extra leading dimension of size 1.
    """
    tensor = torch.from_numpy(img_np)
    batched = tensor[None, :]
    return batched
|
|
|
|
|
|
|
|
|
def torch_to_np(img_var):
    """Convert a torch tensor to a numpy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion. The shape is returned unchanged.

    NOTE(review): an older docstring claimed 1 x C x W x H -> C x W x H, but
    no batch axis is removed here; confirm callers expect the full shape.

    :param img_var: torch.Tensor to convert.
    :return: numpy.ndarray with the same shape as ``img_var``.
    """
    detached = img_var.detach()
    return detached.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
def save_image(name, image_np, output_path="output/normal/"):
    """Save ``image_np`` as ``<output_path>/<name>.png``.

    The output directory, including any missing parents, is created first.
    The array is converted with ``np_to_pil``.

    Args:
        name: File stem; ``.png`` is appended.
        image_np: Image array in the layout ``np_to_pil`` accepts.
        output_path: Target directory; a trailing separator is optional.
            An empty string means the current directory.
    """
    # makedirs(exist_ok=True) handles missing parents and avoids the
    # exists()/mkdir() race of the previous implementation.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    p = np_to_pil(image_np)
    # os.path.join keeps the file inside output_path even without a trailing
    # slash (plain concatenation wrote "<dir>name.png" next to the directory).
    p.save(os.path.join(output_path, "{}.png".format(name)))
|
|
|
|
|
|
|
|
|
def np_to_pil(img_np):
    """Convert a channel-first numpy image to a PIL image.

    From C x W x H [0..1] to W x H x C [0...255]. Values are scaled by 255,
    clipped to [0, 255] and cast to uint8. A single-channel input becomes a
    2-D image. Otherwise the input must have exactly 3 channels (checked by
    an ``assert``) and is transposed to channel-last.

    :param img_np: numpy array with the channel axis first.
    :return: PIL.Image.Image
    """
    pixels = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    if img_np.shape[0] != 1:
        assert img_np.shape[0] == 3, img_np.shape
        pixels = pixels.transpose(1, 2, 0)
    else:
        pixels = pixels[0]
    return Image.fromarray(pixels)