import cv2 from torchvision import transforms import numpy as np import torch def re_normalize(image_tensor, old='[-1,1]', new='imagenet'): """ Re-normalizes an image tensor from one normalization scheme to another. Args: image_tensor (torch.Tensor): Image tensor to be re-normalized. old (str): Old normalization scheme. Options: '[-1,1]', 'imagenet'. new (str): New normalization scheme. Options: '[-1,1]', 'imagenet'. Returns: torch.Tensor: Re-normalized image tensor. """ # Old normalization parameters device = image_tensor.device if old == '[-1,1]': old_mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device) old_std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device) elif old == 'imagenet': old_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) old_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) elif old == '[0,1]': old_mean = torch.tensor([0.0, 0.0, 0.0]).view(1, 3, 1, 1).to(device) old_std = torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1).to(device) else: print('old normalization not implemented') raise NotImplementedError # New normalization parameters if new == '[-1,1]': new_mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device) new_std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device) elif new == 'imagenet': new_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) new_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) elif new == '[0,1]': new_mean = torch.tensor([0.0, 0.0, 0.0]).view(1, 3, 1, 1).to(device) new_std = torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1).to(device) else: print('new normalization not implemented') raise NotImplementedError # Step 1: Denormalize the image tensor using the old mean and std denormalized_image = image_tensor * old_std + old_mean # Step 2: Normalize the image tensor using the new mean and std normalized_image = (denormalized_image - new_mean) / new_std return normalized_image def wrap_transforms(image_transforms_type, image_size): if image_transforms_type == 'basic_imagenet': MEAN = [0.485, 0.456, 0.406] STD = [0.229, 0.224, 0.225] return transforms.Compose([ transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize(mean=MEAN, std=STD) ]) else: raise NotImplementedError # def enhance_contrast_clahe(image): # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) # lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) # lab_planes = list( cv2.split(lab) ) # lab_planes[0] = clahe.apply(lab_planes[0]) # lab = cv2.merge(lab_planes) # image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) # return image