| import random |
| import numpy as np |
| import torch |
| from torchvision import transforms |
| from torchvision.transforms import functional as F |
| from torch.nn.functional import pad |
|
|
|
|
class RITE_Transform:
    """Joint image/mask preprocessing and augmentation for the RITE dataset.

    Training (``is_train=True``) applies the augmentations enabled in
    ``config['data_transforms']``: horizontal flip, rotation, saturation and
    brightness. Geometric augmentations are applied to the mask as well.
    Training then takes a random ``img_size`` x ``img_size`` crop when
    ``use_random_crop`` is set.

    Whenever no random crop is taken (evaluation, or training with cropping
    disabled), both tensors are resized so the longest side equals
    ``img_size``. The bottom/right is then zero-padded to a square.

    With ``apply_norm=True`` the image is linearly mapped from
    ``[a_min, a_max]`` to ``[0, 255]``, clamped, and standardised per channel.
    """

    # Per-channel mean/std on the 0-255 scale. These are the ImageNet
    # statistics (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) times 255.
    # Note the red std is 0.229 * 255 = 58.395.
    PIXEL_MEAN = (123.675, 116.28, 103.53)
    PIXEL_STD = (58.395, 57.12, 57.375)

    # Number of random-crop attempts. The last attempt is centred on a
    # foreground pixel.
    CROP_ATTEMPTS = 20

    def __init__(self, config):
        """Read augmentation settings from ``config['data_transforms']``.

        Args:
            config: mapping whose ``'data_transforms'`` entry provides
                ``rotation_angle``, ``saturation``, ``brightness`` and
                ``img_size``. ``__call__`` also reads the ``use_*`` flags and
                ``a_min``/``a_max`` from that same entry.
        """
        data_transforms = config['data_transforms']
        self.pixel_mean = torch.Tensor(self.PIXEL_MEAN).view(-1, 1, 1)
        self.pixel_std = torch.Tensor(self.PIXEL_STD).view(-1, 1, 1)
        self.degree = data_transforms['rotation_angle']
        self.saturation = data_transforms['saturation']
        self.brightness = data_transforms['brightness']
        self.img_size = data_transforms['img_size']
        # Kept for callers that use this attribute directly. torchvision's
        # Resize matches the *shorter* edge to img_size - 1 and only caps the
        # longer edge at img_size, so square inputs become (img_size - 1)^2.
        # __call__ therefore uses _resize_longest_side instead.
        self.resize = transforms.Resize(self.img_size - 1, max_size=self.img_size, antialias=True)
        self.data_transforms = data_transforms

    def __call__(self, img, mask, apply_norm=True, is_train=True):
        """Transform an image/mask pair.

        Args:
            img: image tensor with spatial dims last. With ``apply_norm`` it
                must broadcast against the (3, 1, 1) mean/std tensors.
            mask: label tensor with the same spatial size as ``img``.
                Presumably (1, H, W); TODO confirm against the dataset.
            apply_norm: rescale the image to 0-255 and standardise it.
            is_train: apply random augmentations and, if enabled, cropping.

        Returns:
            ``(img, mask)``, both ``img_size`` x ``img_size`` spatially.
        """
        cropped = False
        if is_train:
            img, mask = self._augment(img, mask)
            if self.data_transforms['use_random_crop']:
                img, mask = self._random_crop(img, mask)
                cropped = True
        if not cropped:
            img, mask = self._resize_and_pad(img, mask)
        if apply_norm:
            img = self._normalize(img)
        return img, mask

    def _augment(self, img, mask):
        """Apply the enabled random flip / rotation / colour augmentations."""
        cfg = self.data_transforms
        if cfg['use_horizontal_flip'] and random.random() < 0.5:
            img = F.hflip(img)
            mask = F.hflip(mask)

        if cfg['use_rotation'] and random.random() < 0.5 and self.degree >= 1:
            # Positive angle drawn uniformly from 1..degree (inclusive).
            deg = 1 + random.randrange(self.degree)
            img = F.rotate(img, angle=deg)
            mask = F.rotate(mask, angle=deg)

        # NOTE(review): torchvision colour ops clamp float tensors to [0, 1].
        # Confirm img is uint8 (or already in [0, 1]) at this point.
        if cfg['use_saturation'] and random.random() < 0.2:
            img = F.adjust_saturation(img, self.saturation)

        if cfg['use_brightness'] and random.random() < 0.5:
            # Factor lies in [0.5 * brightness, brightness).
            img = F.adjust_brightness(img, self.brightness * max(0.5, random.random()))
        return img, mask

    def _random_crop(self, img, mask):
        """Take an img_size x img_size crop, preferring one that has foreground.

        Uniformly random crops are tried until one contains a nonzero mask
        pixel. An all-zero mask accepts the first crop. On the final attempt
        the window is centred on a randomly chosen nonzero mask pixel; parts
        outside the image rely on F.crop's zero padding. The image must be at
        least img_size on each side (required by RandomCrop.get_params).
        """
        size = (self.img_size, self.img_size)
        has_foreground = bool(mask.any())
        for attempt in range(1, self.CROP_ATTEMPTS + 1):
            top, left, h, w = transforms.RandomCrop.get_params(img, size)
            last = attempt >= self.CROP_ATTEMPTS
            if has_foreground and last:
                top, left = self._foreground_centered_origin(mask, h, w)
            cropped_img = F.crop(img, top, left, h, w)
            cropped_mask = F.crop(mask, top, left, h, w)
            if not has_foreground or last or cropped_mask.any():
                break
        return cropped_img, cropped_mask

    @staticmethod
    def _foreground_centered_origin(mask, h, w):
        """Return (top, left) of an h x w window centred on a random nonzero mask pixel."""
        coords = torch.nonzero(mask)
        row, col = coords[random.randrange(len(coords))][-2:].tolist()
        return row - h // 2, col - w // 2

    def _resize_longest_side(self, tensor, interpolation, antialias):
        """Resize so the longer spatial side equals img_size, keeping aspect ratio."""
        h, w = tensor.shape[-2:]
        scale = self.img_size / max(h, w)
        new_size = [max(1, int(h * scale + 0.5)), max(1, int(w * scale + 0.5))]
        return F.resize(tensor, new_size, interpolation=interpolation, antialias=antialias)

    def _resize_and_pad(self, img, mask):
        """Resize (longest side -> img_size), then zero-pad bottom/right to a square."""
        b_min = 0
        img = self._resize_longest_side(
            img, transforms.InterpolationMode.BILINEAR, antialias=True)
        # Nearest-neighbour keeps mask values within the original label set.
        # Bilinear/antialiased resizing would blend labels at object edges.
        mask = self._resize_longest_side(
            mask, transforms.InterpolationMode.NEAREST, antialias=False)
        h, w = img.shape[-2:]
        padding = (0, self.img_size - w, 0, self.img_size - h)
        return pad(img, padding, value=b_min), pad(mask, padding, value=b_min)

    def _normalize(self, img):
        """Map [a_min, a_max] to [0, 255], clamp, then standardise per channel."""
        a_min = self.data_transforms['a_min']
        a_max = self.data_transforms['a_max']
        b_min, b_max = 0, 255
        if not torch.is_floating_point(img):
            # Integer tensors (e.g. uint8) would wrap around on `img - a_min`.
            img = img.float()
        img = (img - a_min) / (a_max - a_min)
        img = torch.clamp(img * (b_max - b_min) + b_min, b_min, b_max)
        mean = self.pixel_mean.to(img.device)
        std = self.pixel_std.to(img.device)
        return (img - mean) / std