| | |
| | |
| |
|
| | import torch |
| | import torchvision.transforms |
| | import torchvision.transforms.functional as F |
| |
|
| | |
| | |
| |
|
| | class ComposePair(torchvision.transforms.Compose): |
| | def __call__(self, img1, img2): |
| | for t in self.transforms: |
| | img1, img2 = t(img1, img2) |
| | return img1, img2 |
| |
|
| | class NormalizeBoth(torchvision.transforms.Normalize): |
| | def forward(self, img1, img2): |
| | img1 = super().forward(img1) |
| | img2 = super().forward(img2) |
| | return img1, img2 |
| |
|
| | class ToTensorBoth(torchvision.transforms.ToTensor): |
| | def __call__(self, img1, img2): |
| | img1 = super().__call__(img1) |
| | img2 = super().__call__(img2) |
| | return img1, img2 |
| | |
| | class RandomCropPair(torchvision.transforms.RandomCrop): |
| | |
| | def forward(self, img1, img2): |
| | img1 = super().forward(img1) |
| | img2 = super().forward(img2) |
| | return img1, img2 |
| |
|
| | class ColorJitterPair(torchvision.transforms.ColorJitter): |
| | |
| | def __init__(self, assymetric_prob, **kwargs): |
| | super().__init__(**kwargs) |
| | self.assymetric_prob = assymetric_prob |
| | def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor): |
| | for fn_id in fn_idx: |
| | if fn_id == 0 and brightness_factor is not None: |
| | img = F.adjust_brightness(img, brightness_factor) |
| | elif fn_id == 1 and contrast_factor is not None: |
| | img = F.adjust_contrast(img, contrast_factor) |
| | elif fn_id == 2 and saturation_factor is not None: |
| | img = F.adjust_saturation(img, saturation_factor) |
| | elif fn_id == 3 and hue_factor is not None: |
| | img = F.adjust_hue(img, hue_factor) |
| | return img |
| | |
| | def forward(self, img1, img2): |
| |
|
| | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( |
| | self.brightness, self.contrast, self.saturation, self.hue |
| | ) |
| | img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) |
| | if torch.rand(1) < self.assymetric_prob: |
| | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( |
| | self.brightness, self.contrast, self.saturation, self.hue |
| | ) |
| | img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) |
| | return img1, img2 |
| |
|
| | def get_pair_transforms(transform_str, totensor=True, normalize=True): |
| | |
| | trfs = [] |
| | for s in transform_str.split('+'): |
| | if s.startswith('crop'): |
| | size = int(s[len('crop'):]) |
| | trfs.append(RandomCropPair(size)) |
| | elif s=='acolor': |
| | trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0)) |
| | elif s=='': |
| | pass |
| | else: |
| | raise NotImplementedError('Unknown augmentation: '+s) |
| | |
| | if totensor: |
| | trfs.append( ToTensorBoth() ) |
| | if normalize: |
| | trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ) |
| |
|
| | if len(trfs)==0: |
| | return None |
| | elif len(trfs)==1: |
| | return trfs |
| | else: |
| | return ComposePair(trfs) |
| | |
| | |
| | |
| | |
| | |
| |
|