AMontiB
Your original commit message (now includes LFS pointer)
9c4b1c4
'''
Copyright 2024 Image Processing Research Group of University Federico
II of Naples ('GRIP-UNINA'). All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
import random
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.v2 as Tv2
from PIL import Image
import torch
def make_processing(opt):
opt = parse_arguments(opt)
transforms_list = list() # list of transforms
if opt.task == 'train':
transforms_aug = make_aug(opt) # make data-augmentation transforms
if transforms_aug is not None:
transforms_list.append(transforms_aug)
transforms_post = make_post(opt) # make post-data-augmentation transforms
if transforms_post is not None:
transforms_list.append(transforms_post)
if opt.task == 'test' and 'realFORLAB:pre' in opt.data_keys:
transforms_list.append(Tv2.CenterCrop(1024))
if opt.task == 'test' and 'realFORLAB:fb' in opt.data_keys:
transforms_list.append(Tv2.CenterCrop(720))
if opt.task == 'test' and 'realFORLAB:tw' in opt.data_keys:
transforms_list.append(Tv2.CenterCrop(1200))
if opt.task == 'test' and 'realFORLAB:tl' in opt.data_keys:
transforms_list.append(Tv2.CenterCrop(800))
transforms_list.append(make_normalize(opt)) # make normalization
return Tv2.Compose(transforms_list)
def add_processing_arguments(parser):
# parser is an argparse.ArgumentParser
#
# ICASSP2023: --cropSize 96 --loadSize -1 --resizeSize -1 --norm_type resnet --resize_prob 0.2 --jitter_prob 0.8 --colordist_prob 0.2 --cutout_prob 0.2 --noise_prob 0.2 --blur_prob 0.5 --cmp_prob 0.5 --rot90_prob 1.0 --hpf_prob 0.0 --blur_sig 0.0,3.0 --cmp_method cv2,pil --cmp_qual 30,100 --resize_size 256 --resize_ratio 0.75
#
parser.add_argument("--cropSize",type=int,default=-1,help="crop images to this size post augumentation")
# data-augmentation probabilities
parser.add_argument("--resize_prob", type=float, default=0.0)
parser.add_argument("--jitter_prob", type=float, default=0.0)
parser.add_argument("--colordist_prob", type=float, default=0.0)
parser.add_argument("--cutout_prob", type=float, default=0.0)
parser.add_argument("--noise_prob", type=float, default=0.0)
parser.add_argument("--blur_prob", type=float, default=0.0)
parser.add_argument("--cmp_prob", type=float, default=0.0)
# data-augmentation parameters
parser.add_argument("--blur_sig", default="0.5")
parser.add_argument("--cmp_qual", default="75")
parser.add_argument("--resize_size", type=int, default=256)
parser.add_argument("--resize_ratio", type=float, default=1.0)
# other
parser.add_argument("--norm_type", type=str, default="resnet") # normalization type
return parser
def parse_arguments(opt):
if not isinstance(opt.blur_sig, list):
opt.blur_sig = [float(s) for s in opt.blur_sig.split(",")]
if not isinstance(opt.cmp_qual, list):
opt.cmp_qual = [int(s) for s in opt.cmp_qual.split(",")]
print(opt.cmp_qual)
return opt
def make_post(opt):
transforms_list = list()
if opt.cropSize > 0:
print("\nUsing Post Random Crop\n")
transforms_list.append(Tv2.RandomCrop(opt.cropSize, pad_if_needed=True, padding_mode="symmetric"))
if len(transforms_list) == 0:
return None
else:
return Tv2.Compose(transforms_list)
def make_aug(opt):
# AUG
transforms_list_aug = list()
if (opt.resize_size > 0) and (opt.resize_prob > 0): # opt.resized_ratio
transforms_list_aug.append(
Tv2.RandomChoice(
[
Tv2.RandomResizedCrop(
size=opt.resize_size,
scale=(0.08, 1.0),
ratio=(opt.resize_ratio, 1.0 / opt.resize_ratio),
),
Tv2.RandomCrop([opt.resize_size])
],
p=[opt.resize_prob, 1 - opt.resize_prob],
)
)
if opt.jitter_prob > 0:
transforms_list_aug.append(
Tv2.RandomApply(
[
Tv2.ColorJitter(0.4, 0.4, 0.4, 0.1)
],
p=opt.jitter_prob
)
)
if opt.colordist_prob > 0:
transforms_list_aug.append(Tv2.RandomGrayscale(p=opt.colordist_prob))
if opt.cutout_prob > 0:
transforms_list_aug.append(create_cutout_transforms(opt.cutout_prob))
if opt.noise_prob > 0:
transforms_list_aug.append(
Tv2.Compose([
Tv2.ToImage(),
Tv2.ToDtype(torch.float32, scale=False),
Tv2.RandomApply(
[
Tv2.GaussianNoise(sigma=0.44)
],
p=opt.noise_prob
),
Tv2.ToDtype(torch.uint8, scale=False),
Tv2.ToPILImage(),
])
)
if opt.blur_prob > 0:
transforms_list_aug.append(
Tv2.RandomApply(
[
Tv2.GaussianBlur(
kernel_size=15,
sigma=opt.blur_sig,
)
],
p=opt.blur_prob
)
)
if opt.cmp_prob > 0:
transforms_list_aug.append(
Tv2.RandomApply(
[
Tv2.JPEG(
opt.cmp_qual
)
],
opt.cmp_prob,
)
)
transforms_list_aug.append(Tv2.Compose([Tv2.RandomHorizontalFlip(), Tv2.RandomVerticalFlip()]))
if len(transforms_list_aug) > 0:
return Tv2.Compose(transforms_list_aug)
else:
return None
def make_normalize(opt):
transforms_list = list()
if opt.norm_type == "resnet":
print("normalize RESNET")
transforms_list.append(Tv2.ToTensor())
transforms_list.append(
Tv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
else:
assert False
return Tv2.Compose(transforms_list)
def sample_discrete(s):
if len(s) == 1:
return s[0]
return random.choice(s)
def sample_continuous(s):
if len(s) == 1:
return s[0]
if len(s) == 2:
rg = s[1] - s[0]
return random.random() * rg + s[0]
raise ValueError("Length of iterable s should be 1 or 2.")
def create_cutout_transforms(p):
from albumentations import CoarseDropout
aug = CoarseDropout(
num_holes_range=(1,1),
hole_height_range=(1, 48),
hole_width_range=(1, 48),
fill=128,
p=p,
)
return transforms.Lambda(
lambda img: Image.fromarray(aug(image=np.array(img))["image"])
)