| from config import config
|
| from torchvision import transforms
|
| import cv2 as cv
|
| import torchvision.transforms.functional as TF
|
| import random
|
|
|
|
|
| class JointTransformMethod:
|
| def __call__(self, img, label):
|
| img = transforms.ToPILImage()(img).convert('L')
|
| label = transforms.ToPILImage()(label).convert('L')
|
|
|
| if random.random() > 0.5:
|
| img = TF.hflip(img)
|
| label = TF.hflip(label)
|
|
|
| if random.random() > 0.2:
|
| i, j, h, w = transforms.RandomResizedCrop.get_params(
|
| img, scale=(0.8, 1.0), ratio=(0.9, 1.1))
|
| img = TF.resized_crop(img, i, j, h, w, (config.image_size, config.image_size))
|
| label = TF.resized_crop(label, i, j, h, w, (config.image_size, config.image_size))
|
| else:
|
| img = TF.resize(img, (config.image_size, config.image_size))
|
| label = TF.resize(label, (config.image_size, config.image_size))
|
|
|
| img = TF.to_tensor(img)
|
| label = TF.to_tensor(label)
|
|
|
| img = (img - 0.5) / 0.5
|
| label = (label - 0.5) / 0.5
|
|
|
| return img, label
|
|
|
|
|
| class TestTransformMethod:
|
| def __call__(self, img):
|
| img = cv.resize(img, (config.image_size, config.image_size))
|
| if len(img.shape) == 2:
|
| img = img[:, :, None]
|
|
|
| img = transforms.ToTensor()(img)
|
| img = (img - 0.5) / 0.5
|
| return img
|
|
|
|
|
| myDiTTransform = {
|
| 'trainTransform': JointTransformMethod(),
|
| 'testTransform': TestTransformMethod()
|
| }
|
|
|
|
|
| class myTransformMethod():
|
| def __call__(self, img):
|
| img = cv.resize(img, (config.image_size, config.image_size))
|
| if img.shape[-1] == 3:
|
| img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
|
| return img
|
|
|
|
|
| myTransform = {
|
| 'trainTransform': transforms.Compose([
|
| myTransformMethod(),
|
| transforms.ToTensor(),
|
| transforms.Normalize([0.5], [0.5])
|
| ]),
|
| 'testTransform': transforms.Compose([
|
| myTransformMethod(),
|
| transforms.ToTensor(),
|
| transforms.Normalize([0.5], [0.5])
|
| ]),
|
|
|
| }
|
|
|