| | import torch |
| | import torchvision |
| |
|
| | import math |
| | import cv2 |
| | import numpy as np |
| | from scipy.ndimage import rotate |
| |
|
| |
|
class RandCrop(object):
    """Randomly crop an LQ/GT image pair at independently chosen locations.

    The LQ image is cropped to ``crop_size`` and the GT image to
    ``scale * crop_size``. The two windows are sampled independently, so the
    resulting crops are generally *not* spatially aligned; ``RandCrop_pair``
    crops at corresponding locations instead.

    Args:
        crop_size: LQ crop size, an ``int`` (square) or a ``(height, width)``
            tuple.
        scale: factor by which the GT crop window is larger than the LQ one.
    """

    def __init__(self, crop_size, scale):
        assert isinstance(crop_size, (int, tuple))
        if isinstance(crop_size, int):
            self.crop_size = (crop_size, crop_size)
        else:
            assert len(crop_size) == 2
            self.crop_size = crop_size

        self.scale = scale

    @staticmethod
    def _random_offset(size, crop):
        """Return a uniformly random start index for a ``crop``-long window.

        Raises:
            ValueError: if ``crop`` exceeds ``size``.
        """
        if crop > size:
            raise ValueError(
                f"crop size {crop} is larger than image size {size}")
        # randint's upper bound is exclusive, so "+ 1" makes the last valid
        # start (size - crop) reachable and allows crop == size.
        return np.random.randint(0, size - crop + 1)

    def __call__(self, sample):
        """Return ``{'img_LQ': ..., 'img_GT': ...}`` holding the random crops.

        Both images are indexed as ``(rows, cols, channels)`` arrays; other
        keys of ``sample`` are not carried over.

        Raises:
            ValueError: if either image is smaller than its crop window.
        """
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
        new_h, new_w = self.crop_size

        h, w, _ = img_LQ.shape
        top = self._random_offset(h, new_h)
        left = self._random_offset(w, new_w)
        img_LQ_crop = img_LQ[top:top + new_h, left:left + new_w, :]

        # The GT window is `scale` times larger and placed independently.
        gt_h, gt_w = self.scale * new_h, self.scale * new_w
        h, w, _ = img_GT.shape
        top = self._random_offset(h, gt_h)
        left = self._random_offset(w, gt_w)
        img_GT_crop = img_GT[top:top + gt_h, left:left + gt_w, :]

        return {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}
| |
|
| |
|
class RandRotate(object):
    """Rotate an LQ/GT image pair by the same random multiple of 90 degrees.

    Each of 90, 180, 270 degrees (and no rotation) is chosen with
    probability 1/4; both images always receive the same rotation.
    """

    def __call__(self, sample):
        """Return ``{'img_LQ': ..., 'img_GT': ...}`` with both images rotated.

        Rotation is applied over the first two array axes. Other keys of
        ``sample`` are not carried over.
        """
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']

        prob_rotate = np.random.random()
        # Buckets [0, .25), [.25, .5), [.5, .75) -> 1, 2, 3 quarter turns;
        # [.75, 1) -> 0 quarter turns (image left unchanged).
        quarter_turns = (int(prob_rotate * 4) + 1) % 4
        if quarter_turns:
            # np.rot90 permutes indices exactly (no interpolation). It returns
            # a view with negative strides, so .copy() yields a normal array
            # (torch.from_numpy rejects negative strides).
            img_LQ = np.rot90(img_LQ, quarter_turns).copy()
            img_GT = np.rot90(img_GT, quarter_turns).copy()

        return {'img_LQ': img_LQ, 'img_GT': img_GT}
| |
|
| |
|
class RandHorizontalFlip(object):
    """Mirror an LQ/GT image pair left-to-right together, half of the time."""

    def __call__(self, sample):
        """Return a new ``{'img_LQ', 'img_GT'}`` dict, flipped with p = 0.5.

        When no flip happens the original array objects are returned
        unchanged; other keys of ``sample`` are not carried over.
        """
        lq_image = sample['img_LQ']
        gt_image = sample['img_GT']

        if not np.random.random() < 0.5:
            return {'img_LQ': lq_image, 'img_GT': gt_image}

        # np.fliplr returns a reversed view; .copy() materialises it as a
        # regular, positively-strided array.
        return {
            'img_LQ': np.fliplr(lq_image).copy(),
            'img_GT': np.fliplr(gt_image).copy(),
        }
| |
|
| |
|
class ToTensor(object):
    """Turn the LQ/GT numpy arrays of a sample into channel-first tensors.

    The last array axis is moved to the front (``(2, 0, 1)`` transpose —
    presumably HWC -> CHW). No dtype conversion or value scaling is done,
    and ``torch.from_numpy`` shares memory with the input arrays.
    """

    def __call__(self, sample):
        """Return ``{'img_LQ': tensor, 'img_GT': tensor}``."""
        keys = ('img_LQ', 'img_GT')
        arrays = [sample[key] for key in keys]

        # Channel axis first, then wrap without copying.
        channel_first = [array.transpose((2, 0, 1)) for array in arrays]
        tensors = [torch.from_numpy(array) for array in channel_first]

        return dict(zip(keys, tensors))
| |
|
| |
|
class VGG19PerceptualLoss(torch.nn.Module):
    """L1 distance between frozen VGG19 feature maps of two inputs.

    Args:
        feature_layer: number of leading layers of torchvision's
            ``vgg19().features`` to keep (layers ``[0, feature_layer)``).
    """

    def __init__(self, feature_layer=35):
        super(VGG19PerceptualLoss, self).__init__()
        # Loads torchvision's default pretrained VGG19 weights (torchvision
        # may download them on first use).
        model = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
        self.features = torch.nn.Sequential(*list(model.features.children())[:feature_layer]).eval()

        # Freeze the extractor: it is a fixed metric, not trained.
        self.features.requires_grad_(False)

    def forward(self, source, target):
        """Return the mean absolute difference of the extracted features.

        NOTE(review): inputs are fed to VGG as-is; pretrained torchvision
        weights normally expect ImageNet-normalised images — confirm callers
        normalise beforehand.
        """
        return torch.nn.functional.l1_loss(self.features(source), self.features(target))
| | |
| |
|
class RandCrop_pair(object):
    """Randomly crop an LQ/GT image pair at corresponding locations.

    A random ``crop_size`` window is taken from the LQ image; its offset and
    size are multiplied by ``scale`` to locate the GT window, so (assuming GT
    is LQ upscaled by ``scale``) both crops cover the same region.

    Args:
        crop_size: LQ crop size, an ``int`` (square) or a ``(height, width)``
            tuple.
        scale: integer factor between GT and LQ coordinates.
    """

    def __init__(self, crop_size, scale):
        assert isinstance(crop_size, (int, tuple))
        if isinstance(crop_size, int):
            self.crop_size = (crop_size, crop_size)
        else:
            assert len(crop_size) == 2
            self.crop_size = crop_size

        self.scale = scale

    def __call__(self, sample):
        """Return ``{'img_LQ': ..., 'img_GT': ...}`` holding the aligned crops.

        Both images are indexed as ``(rows, cols, channels)`` arrays; other
        keys of ``sample`` are not carried over.

        Raises:
            ValueError: if the LQ image is smaller than ``crop_size`` or the
                scaled window falls outside the GT image.
        """
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
        new_h, new_w = self.crop_size

        h, w, _ = img_LQ.shape
        if new_h > h or new_w > w:
            raise ValueError(
                f"crop size {(new_h, new_w)} is larger than LQ image {(h, w)}")
        # randint's upper bound is exclusive, so "+ 1" makes the last valid
        # offset reachable and allows a crop equal to the image size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        img_LQ_crop = img_LQ[top:top + new_h, left:left + new_w, :]

        # Map the LQ window into GT coordinates.
        gt_top, gt_left = self.scale * top, self.scale * left
        gt_h, gt_w = self.scale * new_h, self.scale * new_w
        h, w, _ = img_GT.shape
        if gt_top + gt_h > h or gt_left + gt_w > w:
            # Slicing would otherwise silently return a truncated GT crop.
            raise ValueError(
                f"scaled crop window exceeds GT image {(h, w)}")
        img_GT_crop = img_GT[gt_top:gt_top + gt_h, gt_left:gt_left + gt_w, :]

        return {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}