File size: 6,318 Bytes
658e26c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# pyright: reportMissingModuleSource=false
import numpy as np
from augly.image import functional as aug_functional
import torch
from torchvision import transforms
from torchvision.transforms import functional
from torch.autograd.variable import Variable
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet per-channel statistics (RGB order), shared by the "img" transforms below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image / ndarray -> float tensor in [0, 1] -> ImageNet-normalized tensor.
default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD)),
])

# [0, 1] <-> [-1, 1] ("vqgan" space).
# normalize_vqgan computes (x - 0.5) / 0.5; unnormalize_vqgan computes (x + 1) / 2 = x * 0.5 + 0.5.
normalize_vqgan = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
unnormalize_vqgan = transforms.Normalize(mean=[-1] * 3, std=[1 / 0.5] * 3)

# [0, 1] <-> ImageNet-normalized ("img" space).
# normalize_img computes (x - mean) / std. unnormalize_img uses mean' = -mean/std and
# std' = 1/std, so (x - mean') / std' = x * std + mean inverts it exactly.
normalize_img = transforms.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD))
unnormalize_img = transforms.Normalize(
    mean=[-m / s for m, s in zip(_IMAGENET_MEAN, _IMAGENET_STD)],
    std=[1 / s for s in _IMAGENET_STD],
)
def psnr(x, y, img_space='vqgan'):
    """
    Return the per-image PSNR (in dB) between x and y.

    Args:
        x: image tensor, ex: watermarked / reconstructed image
        y: image tensor with the same shape as x, ex: original image
        img_space: encoding of x and y.
            'vqgan': values approx. in [-1, 1]; mapped back to [0, 1] with
                unnormalize_vqgan and clamped to [0, 1].
            'img': ImageNet-normalized; mapped back to [0, 1] with
                unnormalize_img and clamped to [0, 1].
            anything else: used as-is. The formula below uses a peak of 255
                after scaling by 255, so values are expected in [0, 1].
    Returns:
        1-D tensor of PSNR values, one per image. Any leading dims are
        flattened into the batch dim B via reshape to BxCxHxW.
    """
    if img_space == 'vqgan':
        to_pixels = unnormalize_vqgan
    elif img_space == 'img':
        to_pixels = unnormalize_img
    else:
        to_pixels = None

    if to_pixels is None:
        diff = x - y
    else:
        diff = torch.clamp(to_pixels(x), 0, 1) - torch.clamp(to_pixels(y), 0, 1)

    channels, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
    diff = (255 * diff).reshape(-1, channels, height, width)  # BxCxHxW
    mse = torch.mean(diff ** 2, dim=(1, 2, 3))  # B
    return 20 * np.log10(255) - 10 * torch.log10(mse)
def center_crop(x, scale):
    """ Perform center crop such that the area of the crop is a given fraction of the image area.

    Each spatial side is multiplied by sqrt(scale) and truncated to int, so the
    crop keeps roughly `scale` of the pixels and preserves the aspect ratio.

    Args:
        x: image tensor whose last two dims are (H, W) (it must have `.shape`)
        scale: target area scale, ex: 0.5 keeps about half of the pixels
    Returns:
        Center crop of x with spatial size (int(H*sqrt(scale)), int(W*sqrt(scale))).
    """
    side_scale = np.sqrt(scale)
    # torchvision's center_crop expects (height, width), and x.shape[-2:] is
    # already (H, W). Do not reverse it, or non-square images get H and W swapped.
    new_edges_size = [int(s * side_scale) for s in x.shape[-2:]]
    return functional.center_crop(x, new_edges_size)
def resize(x, scale):
    """ Resize an image so that its area is scaled by a given factor.

    Each spatial side is multiplied by sqrt(scale) and truncated to int, so the
    output has roughly `scale` times as many pixels and the same aspect ratio.

    Args:
        x: image tensor whose last two dims are (H, W) (it must have `.shape`)
        scale: target area scale, ex: 0.25 halves both height and width
    Returns:
        x resized to spatial size (int(H*sqrt(scale)), int(W*sqrt(scale))).
    """
    side_scale = np.sqrt(scale)
    # torchvision's resize expects (height, width), and x.shape[-2:] is already
    # (H, W). Do not reverse it, or non-square images get H and W swapped.
    new_edges_size = [int(s * side_scale) for s in x.shape[-2:]]
    return functional.resize(x, new_edges_size)
def rotate(x, angle):
    """ Rotate an image by `angle` degrees.

    Thin wrapper over torchvision's functional.rotate, using its default
    interpolation, expand, center and fill settings.

    Args:
        x: image (PIL or tensor)
        angle: rotation angle in degrees
    Returns:
        The rotated image.
    """
    rotated = functional.rotate(x, angle)
    return rotated
def flip(x, direction='horizontal'):
    """ Flip an image horizontally (left-right) or vertically (top-bottom).

    Args:
        x: image (PIL or tensor)
        direction: 'horizontal' (default) or 'vertical'
    Returns:
        The flipped image.
    Raises:
        ValueError: if direction is neither 'horizontal' nor 'vertical'.
    """
    if direction == 'horizontal':
        return functional.hflip(x)
    if direction == 'vertical':
        return functional.vflip(x)
    # Fail loudly on a typo instead of silently returning None.
    raise ValueError(f"direction must be 'horizontal' or 'vertical', got {direction!r}")
def adjust_brightness(x, brightness_factor):
    """ Adjust the brightness of an ImageNet-normalized image tensor.

    The adjustment happens in [0, 1] pixel space. x is unnormalized, passed to
    torchvision's adjust_brightness, then normalized again.

    Args:
        x: ImageNet-normalized image tensor
        brightness_factor: non-negative factor (0 -> black image, 1 -> unchanged)
    Returns:
        ImageNet-normalized adjusted tensor.
    """
    pixels = unnormalize_img(x)
    adjusted = functional.adjust_brightness(pixels, brightness_factor)
    return normalize_img(adjusted)
def adjust_contrast(x, contrast_factor):
    """ Adjust the contrast of an ImageNet-normalized image tensor.

    The adjustment happens in [0, 1] pixel space. x is unnormalized, passed to
    torchvision's adjust_contrast, then normalized again.

    Args:
        x: ImageNet-normalized image tensor
        contrast_factor: non-negative factor (0 -> solid gray image, 1 -> unchanged)
    Returns:
        ImageNet-normalized adjusted tensor.
    """
    pixels = unnormalize_img(x)
    adjusted = functional.adjust_contrast(pixels, contrast_factor)
    return normalize_img(adjusted)
def adjust_saturation(x, saturation_factor):
    """ Adjust the color saturation of an ImageNet-normalized image tensor.

    The adjustment happens in [0, 1] pixel space. x is unnormalized, passed to
    torchvision's adjust_saturation, then normalized again.

    Args:
        x: ImageNet-normalized image tensor
        saturation_factor: non-negative factor (0 -> grayscale, 1 -> unchanged)
    Returns:
        ImageNet-normalized adjusted tensor.
    """
    pixels = unnormalize_img(x)
    adjusted = functional.adjust_saturation(pixels, saturation_factor)
    return normalize_img(adjusted)
def adjust_hue(x, hue_factor):
    """ Shift the hue of an ImageNet-normalized image tensor.

    The adjustment happens in [0, 1] pixel space. x is unnormalized, passed to
    torchvision's adjust_hue, then normalized again.

    Args:
        x: ImageNet-normalized image tensor
        hue_factor: hue shift; torchvision requires it to lie in [-0.5, 0.5]
    Returns:
        ImageNet-normalized adjusted tensor.
    """
    pixels = unnormalize_img(x)
    adjusted = functional.adjust_hue(pixels, hue_factor)
    return normalize_img(adjusted)
def adjust_gamma(x, gamma, gain=1):
    """ Apply gamma correction to an ImageNet-normalized image tensor.

    The correction happens in [0, 1] pixel space. x is unnormalized, passed to
    torchvision's adjust_gamma (which computes gain * pixel ** gamma), then
    normalized again.

    Args:
        x: ImageNet-normalized image tensor
        gamma: non-negative gamma exponent
        gain: multiplicative constant (default 1)
    Returns:
        ImageNet-normalized adjusted tensor.
    """
    pixels = unnormalize_img(x)
    adjusted = functional.adjust_gamma(pixels, gamma, gain)
    return normalize_img(adjusted)
def adjust_sharpness(x, sharpness_factor):
    """ Adjust the sharpness of an ImageNet-normalized image tensor.

    The adjustment happens in [0, 1] pixel space. x is unnormalized, passed to
    torchvision's adjust_sharpness, then normalized again.

    Args:
        x: ImageNet-normalized image tensor
        sharpness_factor: non-negative factor (0 -> blurred, 1 -> unchanged, 2 -> sharpened)
    Returns:
        ImageNet-normalized adjusted tensor.
    """
    pixels = unnormalize_img(x)
    adjusted = functional.adjust_sharpness(pixels, sharpness_factor)
    return normalize_img(adjusted)
def overlay_text(x, text='Lorem Ipsum'):
    """ Overlay text on every image of a batch using augly.

    Each image is unnormalized to [0, 1], clamped, converted to PIL, and passed
    to augly's overlay_text. Only `text` is passed, so augly's defaults are used
    for font, size, color and position. The result is converted back to a tensor.

    Args:
        x: batch of ImageNet-normalized image tensors, iterated over its first
            dim (presumably BxCxHxW)
        text: text to overlay
    Returns:
        ImageNet-normalized tensor with the same shape as x, on x's device.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    img_aug = torch.zeros_like(x, device=x.device)
    for ii, img in enumerate(x):
        # ToPILImage scales float tensors by 255 and casts to uint8 without
        # clamping, so out-of-range pixels would not saturate. Clamp them first.
        pil_img = to_pil(torch.clamp(unnormalize_img(img), 0, 1))
        img_aug[ii] = to_tensor(aug_functional.overlay_text(pil_img, text=text))
    return normalize_img(img_aug)
def jpeg_compress(x, quality_factor):
    """ Apply JPEG compression to every image of a batch using augly.

    Each image is unnormalized to [0, 1], clamped, converted to PIL, and
    re-encoded with augly's encoding_quality. The result is converted back to
    a tensor.

    Args:
        x: batch of ImageNet-normalized image tensors, iterated over its first
            dim (presumably BxCxHxW)
        quality_factor: JPEG quality passed to augly's encoding_quality
            (lower means stronger compression)
    Returns:
        ImageNet-normalized tensor with the same shape as x, on x's device.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    img_aug = torch.zeros_like(x, device=x.device)
    for ii, img in enumerate(x):
        # ToPILImage scales float tensors by 255 and casts to uint8 without
        # clamping, so out-of-range pixels would not saturate. Clamp them first.
        pil_img = to_pil(torch.clamp(unnormalize_img(img), 0, 1))
        img_aug[ii] = to_tensor(aug_functional.encoding_quality(pil_img, quality=quality_factor))
    return normalize_img(img_aug)
def gaussian_noise(input, stddev):
    """ Add i.i.d. Gaussian noise to an image in [0, 1] pixel space.

    Args:
        input: ImageNet-normalized image tensor (any shape)
        stddev: despite its name, this is the noise *variance*. The noise
            standard deviation is stddev ** 0.5. Kept as-is so existing callers'
            values keep the same meaning.
    Returns:
        ImageNet-normalized tensor. Noisy pixels are clamped to [0, 1] before
        re-normalization.
    """
    pixels = unnormalize_img(input)
    # randn_like samples on input's device with input's dtype and works for any rank.
    noise = torch.randn_like(pixels) * (stddev ** 0.5)
    # unnormalize_img maps to [0, 1] pixel space, so [0, 1] is the valid clamp
    # range here. [-1, 1] would let negative pixel values through.
    noisy = torch.clamp(pixels + noise, 0, 1)
    return normalize_img(noisy)
def adjust_gaussian_blur(img, ks):
    """ Blur an ImageNet-normalized image tensor with a Gaussian kernel.

    The blur happens in [0, 1] pixel space. img is unnormalized, passed to
    torchvision's gaussian_blur, then normalized again.

    Args:
        img: ImageNet-normalized image tensor
        ks: kernel size, an int or an (kx, ky) pair. torchvision requires odd,
            positive values. sigma is left to torchvision's default, which it
            derives from the kernel size.
    Returns:
        ImageNet-normalized blurred tensor.
    """
    return normalize_img(functional.gaussian_blur(unnormalize_img(img), kernel_size=ks))