|
|
import os |
|
|
import torch |
|
|
import functools |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from PIL import Image, ImageFile |
|
|
from torch.utils.data import Dataset |
|
|
from tqdm import tqdm |
|
|
import re |
|
|
# Lowercase file suffixes accepted by image_loader().
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


# Tolerate truncated/partially-written image files instead of raising on decode.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (iterable of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # str.endswith accepts a tuple of suffixes, so one call covers them all.
    return filename.lower().endswith(tuple(extensions))
|
|
|
|
|
|
|
|
def image_loader(image_name):
    """Open *image_name* and return it converted to RGB.

    Falls through (returning None implicitly) when the file's extension is
    not one of IMG_EXTENSIONS.
    """
    if not has_file_allowed_extension(image_name, IMG_EXTENSIONS):
        return None
    img = Image.open(image_name)
    return img.convert('RGB')
|
|
|
|
|
|
|
|
def get_default_img_loader():
    """Return the default image-loading callable.

    The original wrapped image_loader in functools.partial with no bound
    arguments, which adds nothing; the function itself is returned directly.
    """
    return image_loader
|
|
|
|
|
|
|
|
class ImageDataset2(Dataset):
    """Patch-based IQA dataset with MOS, distortion type and scene labels.

    Reads a tab-separated annotation file (no header) whose columns are:
    image path, mos, distortion type, scene content 1, 2 and 3.
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 num_patch,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            img_dir (string): Directory of the images.
            preprocess (callable): transform applied to each loaded image.
            num_patch (int): number of 224x224 patches sampled per image.
            test (bool): if True, patches are picked deterministically.
            get_loader (callable): factory that returns the image loader.
        """
        self.data = pd.read_csv(csv_file, sep='\t', header=None)
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.test = test

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'I' (num_patch x 3 x 224 x 224 patch tensor), 'mos'
            (float), 'dist_type', the three scene-content labels, and 'valid'
            (count of meaningful scene labels, 1..3).
        """
        image_name = os.path.join(self.img_dir, self.data.iloc[index, 0])
        I = self.loader(image_name)
        I = self.preprocess(I)
        I = I.unsqueeze(0)
        n_channels = 3
        kernel_h = 224
        kernel_w = 224
        # Larger stride for big images keeps the candidate patch count bounded.
        if (I.size(2) >= 1024) or (I.size(3) >= 1024):
            step = 48
        else:
            step = 32
        patches = I.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(
            0, 2, 3, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)

        assert patches.size(0) >= self.num_patch

        if self.test:
            # Deterministic, evenly spaced selection for evaluation.
            sel_step = patches.size(0) // self.num_patch
            sel = torch.arange(self.num_patch) * sel_step
        else:
            sel = torch.randint(low=0, high=patches.size(0), size=(self.num_patch,))
        patches = patches[sel, ...]
        mos = self.data.iloc[index, 1]

        dist_type = self.data.iloc[index, 2]
        scene_content1 = self.data.iloc[index, 3]
        scene_content2 = self.data.iloc[index, 4]
        scene_content3 = self.data.iloc[index, 5]

        # 'valid' = number of usable scene labels; trailing labels may be
        # marked 'invalid' in the annotation file.
        if scene_content2 == 'invalid':
            valid = 1
        elif scene_content3 == 'invalid':
            valid = 2
        else:
            valid = 3

        sample = {'I': patches, 'mos': float(mos), 'dist_type': dist_type,
                  'scene_content1': scene_content1, 'scene_content2': scene_content2,
                  'scene_content3': scene_content3, 'valid': valid}

        return sample

    def __len__(self):
        return len(self.data.index)
|
|
|
|
|
class ImageDataset_qonly(Dataset):
    """Quality-only IQA dataset: 224x224 patches plus a MOS score, and
    per-distortion scores when the annotation table has 23 columns.
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 num_patch,
                 set,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): annotations; .txt (tab-separated, no header),
                .xlsx, or .csv (header row) are supported.
            img_dir (string): Directory of the images.
            preprocess (callable): transform applied to each loaded image.
            num_patch (int): number of 224x224 patches sampled per image.
            set (int): split id kept when the csv has a 'split' column;
                3 keeps every row. (Name shadows the builtin; kept for
                signature compatibility.)
            test (bool): if True, patches are picked deterministically.
            get_loader (callable): factory that returns the image loader.
        """
        if csv_file[-3:] == 'txt':
            self.data = pd.read_csv(csv_file, sep='\t', header=None)
        elif csv_file[-4:] == 'xlsx':
            self.data = pd.read_excel(csv_file, header=0)
        else:
            data = pd.read_csv(csv_file, header=0)
            if ('split' in data.columns) and (set != 3):
                self.data = data[data.split == set]
            else:
                self.data = data
        # The MOS always lives in the second column for every input format.
        self.mos_col = 1
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.test = test

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'I' (num_patch x 3 x 224 x 224 patches), 'mos' (float)
            and 'dists' (distortion-score tensor, or 0 when absent).
        """
        image_name = os.path.join(self.img_dir, self.data.iloc[index, 0])
        # Annotation files may contain Windows-style path separators.
        image_name = image_name.replace('\\', '/')
        I = self.loader(image_name)
        I = self.preprocess(I)
        I = I.unsqueeze(0)
        n_channels = 3
        kernel_h = 224
        kernel_w = 224
        # Larger stride for big images keeps the candidate patch count bounded.
        if (I.size(2) >= 1024) or (I.size(3) >= 1024):
            step = 48
        else:
            step = 32
        patches = I.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(
            0, 2, 3, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)

        assert patches.size(0) >= self.num_patch

        if self.test:
            # Deterministic, evenly spaced selection for evaluation.
            sel_step = patches.size(0) // self.num_patch
            sel = torch.arange(self.num_patch) * sel_step
        else:
            sel = torch.randint(low=0, high=patches.size(0), size=(self.num_patch,))
        patches = patches[sel, ...]
        mos = self.data.iloc[index, self.mos_col]
        if self.data.shape[1] == 23:
            # Every second column after the MOS holds a distortion score.
            distortions = self.data.iloc[index, self.mos_col + 1::2]
            distortions = torch.from_numpy(distortions.to_numpy(dtype=float))
        else:
            distortions = 0

        sample = {'I': patches, 'mos': float(mos), 'dists': distortions}

        return sample

    def __len__(self):
        # NOTE(review): the original defined __len__ twice; the duplicate
        # (len(self.data), identical in effect) was removed.
        return len(self.data.index)
|
|
|
|
|
|
|
|
class ImageDataset_llie(Dataset):
    """Low-light-enhancement IQA dataset: 224x224 patches, a precomputed
    spatial feature vector (loaded from .npy), a MOS score and optional
    per-distortion scores (when the annotation table has 23 columns).
    """

    def __init__(self, csv_file,
                 img_dir,
                 spatialFeat,
                 preprocess,
                 num_patch,
                 set,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): annotations; .txt (tab-separated, no header),
                .xlsx, or .csv (header row) are supported.
            img_dir (string): Directory of the images.
            spatialFeat (string): directory of per-image .npy feature files,
                named after the image's basename stem.
            preprocess (callable): transform applied to each loaded image.
            num_patch (int): number of 224x224 patches sampled per image.
            set (int): split id kept when the csv has a 'split' column;
                3 keeps every row.
            test (bool): if True, patches are picked deterministically.
            get_loader (callable): factory that returns the image loader.
        """
        if csv_file[-3:] == 'txt':
            self.data = pd.read_csv(csv_file, sep='\t', header=None)
        elif csv_file[-4:] == 'xlsx':
            self.data = pd.read_excel(csv_file, header=0)
        else:
            data = pd.read_csv(csv_file, header=0)
            if ('split' in data.columns) and (set != 3):
                self.data = data[data.split == set]
            else:
                self.data = data
        # The MOS always lives in the second column for every input format.
        self.mos_col = 1
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.test = test
        self.spatialFeat = spatialFeat

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'I' (patches), 'spatial_feat' (flattened feature
            tensor), 'mos' (float) and 'dists' (tensor or 0).
        """
        image_name = os.path.join(self.img_dir, self.data.iloc[index, 0])
        image_name = image_name.replace('\\', '/')
        I = self.loader(image_name)
        I = self.preprocess(I)

        # Basename stem (text before the first '.') keys the .npy feature file.
        stem = image_name.split('/')[-1].split('.')[0]
        spatial_feat = torch.from_numpy(
            np.load(os.path.join(self.spatialFeat, f'{stem}.npy'))).view(-1)

        I = I.unsqueeze(0)
        n_channels = 3
        kernel_h = 224
        kernel_w = 224
        # Larger stride for big images keeps the candidate patch count bounded.
        if (I.size(2) >= 1024) or (I.size(3) >= 1024):
            step = 48
        else:
            step = 32
        patches = I.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(
            0, 2, 3, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)

        assert patches.size(0) >= self.num_patch
        # BUG FIX: the original mutated self.num_patch here, leaking state
        # across items/workers. Clamp into a local instead; given the assert
        # above this equals self.num_patch.
        num_patch = min(patches.size(0), self.num_patch)
        if self.test:
            # Deterministic, evenly spaced selection for evaluation.
            sel_step = patches.size(0) // num_patch
            sel = torch.arange(num_patch) * sel_step
        else:
            sel = torch.randint(low=0, high=patches.size(0), size=(num_patch,))
        patches = patches[sel, ...]

        mos = self.data.iloc[index, self.mos_col]
        if self.data.shape[1] == 23:
            # Every second column after the MOS holds a distortion score.
            distortions = self.data.iloc[index, self.mos_col + 1::2]
            distortions = torch.from_numpy(distortions.to_numpy(dtype=float))
        else:
            distortions = 0

        sample = {'I': patches, 'spatial_feat': spatial_feat, 'mos': float(mos),
                  'dists': distortions}

        return sample

    def __len__(self):
        # NOTE(review): the original defined __len__ twice; the duplicate
        # (len(self.data), identical in effect) was removed.
        return len(self.data.index)
|
|
|
|
|
|
|
|
class ImageDataset_llie_naflex(Dataset):
    """LLIE IQA dataset (naflex variant): returns the raw loaded image (no
    preprocessing, no patch extraction), the MOS, and optional distortion
    scores. Resizing is presumably handled downstream — TODO confirm.
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 num_patch,
                 set,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): annotations; .txt (tab-separated, no header),
                .xlsx, or .csv (header row) are supported.
            img_dir (string): Directory of the images.
            preprocess, num_patch, test: stored for interface compatibility;
                not used by __getitem__ in this variant.
            set (int): split id kept when the csv has a 'split' column;
                3 keeps every row.
            get_loader (callable): factory that returns the image loader.
        """
        if csv_file[-3:] == 'txt':
            self.data = pd.read_csv(csv_file, sep='\t', header=None)
        elif csv_file[-4:] == 'xlsx':
            self.data = pd.read_excel(csv_file, header=0)
        else:
            data = pd.read_csv(csv_file, header=0)
            if ('split' in data.columns) and (set != 3):
                self.data = data[data.split == set]
            else:
                self.data = data
        # The MOS always lives in the second column for every input format.
        self.mos_col = 1
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.test = test

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            (image, mos, distortions): the loaded image as returned by the
            loader, the MOS as float, and a distortion-score tensor (or 0).
        """
        image_name = os.path.join(self.img_dir, self.data.iloc[index, 0])
        image_name = image_name.replace('\\', '/')
        I = self.loader(image_name)

        mos = self.data.iloc[index, self.mos_col]
        if self.data.shape[1] == 23:
            # Every second column after the MOS holds a distortion score.
            distortions = self.data.iloc[index, self.mos_col + 1::2]
            distortions = torch.from_numpy(distortions.to_numpy(dtype=float))
        else:
            distortions = 0

        return I, float(mos), distortions

    def __len__(self):
        # NOTE(review): the original defined __len__ twice; the duplicate
        # (len(self.data), identical in effect) was removed.
        return len(self.data.index)
|
|
|
|
|
class ImageDataset_sr_naflex(Dataset):
    """Super-resolution IQA dataset (naflex variant): returns the raw loaded
    image from the 'SR' sub-directory together with its MOS (column 3).
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 num_patch,
                 set,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): .xlsx annotation table with a header row.
            img_dir (string): root directory; images live under img_dir/SR.
            preprocess, num_patch, set, test: stored / accepted for interface
                compatibility; not used by __getitem__ in this variant.
            get_loader (callable): factory that returns the image loader.
        """
        self.data = pd.read_excel(csv_file, header=0)

        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.test = test

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            (image, mos): the loaded image and its MOS as float.
        """
        image_name = os.path.join(self.img_dir, 'SR', self.data.iloc[index, 0])
        image_name = image_name.replace('\\', '/')

        I = self.loader(image_name)

        mos = self.data.iloc[index, 3]

        return I, float(mos)

    def __len__(self):
        # NOTE(review): the original defined __len__ twice (and an unused
        # local 'im_name' in __getitem__); both redundancies were removed.
        return len(self.data.index)
|
|
|
|
|
class ImageDataset_diqa_naflex(Dataset):
    """Paired IQA dataset: an enhanced image (under 'res') with its reference
    (under 'ori') and overall / sharpness / color MOS values.
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 num_patch,
                 set,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): csv with header; either 5 columns
                [res, ori, overall, sharp, color] or 4 columns
                [res, overall, sharp, color] (image is its own reference).
            img_dir (string): root directory holding 'res' and 'ori'.
            preprocess, num_patch, set, test: stored / accepted for interface
                compatibility; not used by __getitem__ in this variant.
            get_loader (callable): factory that returns the image loader.
        """
        self.data = pd.read_csv(csv_file, header=0)

        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.test = test

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            (I, I_ref, overall, sharp, color) — scores affinely rescaled as
            0.8 * s + 1 (presumably mapping onto a 1-5 range; confirm).
        """
        image_name = os.path.join(self.img_dir, 'res', self.data.iloc[index, 0])
        image_name = image_name.replace('\\', '/')
        I = self.loader(image_name)
        if self.data.shape[1] == 5:
            image_name2 = os.path.join(self.img_dir, 'ori', self.data.iloc[index, 1])
            # BUG FIX: the original ran image_name.replace(...) here, which
            # overwrote the reference path with the distorted image's path,
            # so I_ref silently reloaded the distorted image.
            image_name2 = image_name2.replace('\\', '/')
            I_ref = self.loader(image_name2)
            overall_mos = 0.8 * self.data.iloc[index, 2] + 1
            sharp_mos = 0.8 * self.data.iloc[index, 3] + 1
            color_mos = 0.8 * self.data.iloc[index, 4] + 1
        else:
            # No reference column: the image serves as its own reference.
            I_ref = I
            overall_mos = 0.8 * self.data.iloc[index, 1] + 1
            sharp_mos = 0.8 * self.data.iloc[index, 2] + 1
            color_mos = 0.8 * self.data.iloc[index, 3] + 1

        return I, I_ref, float(overall_mos), float(sharp_mos), float(color_mos)

    def __len__(self):
        # NOTE(review): the original defined __len__ twice; the duplicate
        # (len(self.data), identical in effect) was removed.
        return len(self.data.index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageDataset_llie2(Dataset):
    """LLIE IQA dataset: 224x224 patches plus a MOS score and optional
    per-distortion scores (when the annotation table has 23 columns).
    Same as ImageDataset_llie but without the spatial feature vector.
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 num_patch,
                 set,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): annotations; .txt (tab-separated, no header),
                .xlsx, or .csv (header row) are supported.
            img_dir (string): Directory of the images.
            preprocess (callable): transform applied to each loaded image.
            num_patch (int): number of 224x224 patches sampled per image.
            set (int): split id kept when the csv has a 'split' column;
                3 keeps every row.
            test (bool): if True, patches are picked deterministically.
            get_loader (callable): factory that returns the image loader.
        """
        if csv_file[-3:] == 'txt':
            self.data = pd.read_csv(csv_file, sep='\t', header=None)
        elif csv_file[-4:] == 'xlsx':
            self.data = pd.read_excel(csv_file, header=0)
        else:
            data = pd.read_csv(csv_file, header=0)
            if ('split' in data.columns) and (set != 3):
                self.data = data[data.split == set]
            else:
                self.data = data
        # The MOS always lives in the second column for every input format.
        self.mos_col = 1
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.test = test

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'I' (num_patch x 3 x 224 x 224 patches), 'mos' (float)
            and 'dists' (distortion-score tensor, or 0 when absent).
        """
        image_name = os.path.join(self.img_dir, self.data.iloc[index, 0])
        image_name = image_name.replace('\\', '/')
        I = self.loader(image_name)
        I = self.preprocess(I)

        I = I.unsqueeze(0)
        n_channels = 3
        kernel_h = 224
        kernel_w = 224
        # Larger stride for big images keeps the candidate patch count bounded.
        if (I.size(2) >= 1024) or (I.size(3) >= 1024):
            step = 48
        else:
            step = 32
        patches = I.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(
            0, 2, 3, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)

        assert patches.size(0) >= self.num_patch
        # BUG FIX: the original mutated self.num_patch here, leaking state
        # across items/workers. Clamp into a local instead; given the assert
        # above this equals self.num_patch.
        num_patch = min(patches.size(0), self.num_patch)
        if self.test:
            # Deterministic, evenly spaced selection for evaluation.
            sel_step = patches.size(0) // num_patch
            sel = torch.arange(num_patch) * sel_step
        else:
            sel = torch.randint(low=0, high=patches.size(0), size=(num_patch,))
        patches = patches[sel, ...]

        mos = self.data.iloc[index, self.mos_col]
        if self.data.shape[1] == 23:
            # Every second column after the MOS holds a distortion score.
            distortions = self.data.iloc[index, self.mos_col + 1::2]
            distortions = torch.from_numpy(distortions.to_numpy(dtype=float))
        else:
            distortions = 0

        sample = {'I': patches, 'mos': float(mos), 'dists': distortions}

        return sample

    def __len__(self):
        # NOTE(review): the original defined __len__ twice; the duplicate
        # (len(self.data), identical in effect) was removed.
        return len(self.data.index)
|
|
|
|
|
|
|
|
class ImageDataset_pseudo_label(Dataset):
    """Dataset pairing each low-light image with the outputs of several LLIE
    methods, each method carrying a fixed pseudo quality label.
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 num_patch,
                 set,
                 test,
                 pseudo_label,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): Path to the csv file with annotations
                (no header; first column is the image name).
            img_dir (string): root directory; each method's outputs live in
                a sub-directory named after the method.
            preprocess (callable): transform applied to each loaded image.
            num_patch (int): number of 224x224 patches per enhanced image.
            set: accepted for signature compatibility; unused here.
            test (bool): if True, patches are picked deterministically.
            pseudo_label (dict): method name -> pseudo MOS label.
            get_loader (callable): factory that returns the image loader.
        """
        self.data = pd.read_csv(csv_file, header=None)
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.pseudo_label = pseudo_label
        self.test = test

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'I' (patches of every method concatenated along dim 0)
            and 'mos' (tensor of the corresponding pseudo labels).
        """
        image_name = self.data.iloc[index, 0]
        labels = []
        all_patches = []
        for method in self.pseudo_label.keys():
            # Each method stores its output under its own naming scheme.
            if method == 'GT':
                llie_name = image_name
            elif method == 'NeRCo':
                llie_name = method + '_' + image_name[:-4] + '_fake_B.png'
            else:
                llie_name = method + '_' + image_name
            image_path = os.path.join(self.img_dir, method, llie_name)
            I = self.loader(image_path)
            I = self.preprocess(I)
            label = self.pseudo_label[method]

            I = I.unsqueeze(0)
            n_channels = 3
            kernel_h = 224
            kernel_w = 224
            # Larger stride for big images bounds the candidate patch count.
            if (I.size(2) >= 1024) or (I.size(3) >= 1024):
                step = 48
            else:
                step = 32
            patches = I.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(
                0, 2, 3, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)

            assert patches.size(0) >= self.num_patch
            # BUG FIX: the original mutated self.num_patch inside the loop,
            # leaking state across items/workers. Clamp into a local; given
            # the assert above this equals self.num_patch.
            num_patch = min(patches.size(0), self.num_patch)
            if self.test:
                # Deterministic, evenly spaced selection for evaluation.
                sel_step = patches.size(0) // num_patch
                sel = torch.arange(num_patch) * sel_step
            else:
                sel = torch.randint(low=0, high=patches.size(0), size=(num_patch,))
            patches = patches[sel, ...]

            labels.append(label)
            all_patches.append(patches)

        I = torch.cat(all_patches, dim=0)
        labels = torch.tensor(labels)
        sample = {'I': I, 'mos': labels}
        return sample

    def __len__(self):
        return len(self.data.index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageDataset_oppo(Dataset):
    """Dataset of camera shots with tone/color annotation strings; returns
    224x224 patches plus the lower-cased annotation fields.
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 num_patch,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): csv annotation file with a header row.
            img_dir (string): Directory of the images.
            preprocess (callable): transform applied to each loaded image.
            num_patch (int): number of 224x224 patches sampled per image.
            test (bool): if True, patches are picked deterministically.
            get_loader (callable): factory that returns the image loader.
        """
        self.data = pd.read_csv(csv_file, header=0)
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.num_patch = num_patch
        self.test = test

    def __convertnan__(self, value):
        """Map a missing annotation cell (NaN) to the label 'free'."""
        if pd.isna(value):
            value = 'free'
        return value

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'I' (patches) and the lower-cased annotation strings
            (scene, mode, focal_length, compare_x200p, tone/color fields).
        """
        image_name = os.path.join(self.img_dir, self.data.iloc[index, 0])
        I = self.loader(image_name)
        I = self.preprocess(I)
        I = I.unsqueeze(0)
        n_channels = 3
        kernel_h = 224
        kernel_w = 224
        # Larger stride for big images keeps the candidate patch count bounded.
        if (I.size(2) >= 1024) or (I.size(3) >= 1024):
            step = 48
        else:
            step = 32
        patches = I.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(
            0, 2, 3, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)

        assert patches.size(0) >= self.num_patch

        if self.test:
            # Deterministic, evenly spaced selection for evaluation.
            sel_step = patches.size(0) // self.num_patch
            sel = torch.arange(self.num_patch) * sel_step
        else:
            sel = torch.randint(low=0, high=patches.size(0), size=(self.num_patch,))
        patches = patches[sel, ...]

        scene = self.data.iloc[index, 1]
        mode = self.data.iloc[index, 2]
        focal_length = self.data.iloc[index, 3]
        compare_x200p = self.data.iloc[index, 4]
        # Tone/color fields may be empty cells; NaN becomes the label 'free'.
        tone_level = self.__convertnan__(self.data.iloc[index, 5])
        tone_global_issue = self.__convertnan__(self.data.iloc[index, 6])
        tone_local_issue = self.__convertnan__(self.data.iloc[index, 7])
        tone_local_issue_region = self.__convertnan__(self.data.iloc[index, 8])
        color_level = self.__convertnan__(self.data.iloc[index, 9])
        color_global_issue = self.__convertnan__(self.data.iloc[index, 10])
        color_local_issue = self.__convertnan__(self.data.iloc[index, 11])
        color_local_issue_region = self.__convertnan__(self.data.iloc[index, 12])

        sample = {'I': patches, 'scene': scene.lower(), 'mode': mode.lower(),
                  'focal_length': focal_length.lower(), 'compare_x200p': compare_x200p.lower(),
                  'tone_level': tone_level.lower(), 'tone_global_issue': tone_global_issue.lower(),
                  'tone_local_issue': tone_local_issue.lower(),
                  'tone_local_issue_region': tone_local_issue_region.lower(),
                  'color_level': color_level.lower(),
                  'color_global_issue': color_global_issue.lower(),
                  'color_local_issue': color_local_issue.lower(),
                  'color_local_issue_region': color_local_issue_region.lower()}

        return sample

    def __len__(self):
        return len(self.data.index)
|
|
|
|
|
|
|
|
|
|
|
class ImageDataset_llie_general(Dataset):
    """Whole-image LLIE IQA dataset: the preprocessed image plus its MOS
    (no patch extraction).
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess,
                 set,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): annotations; .txt (tab-separated, no header),
                .xlsx, or .csv (header row) are supported.
            img_dir (string): Directory of the images.
            preprocess (callable): transform applied to each loaded image.
            set (int): split id kept when the csv has a 'split' column;
                3 keeps every row.
            test (bool): stored for interface compatibility.
            get_loader (callable): factory that returns the image loader.
        """
        if csv_file[-3:] == 'txt':
            self.data = pd.read_csv(csv_file, sep='\t', header=None)
        elif csv_file[-4:] == 'xlsx':
            self.data = pd.read_excel(csv_file, header=0)
        else:
            data = pd.read_csv(csv_file, header=0)
            if ('split' in data.columns) and (set != 3):
                self.data = data[data.split == set]
            else:
                self.data = data
        # The MOS always lives in the second column for every input format.
        self.mos_col = 1
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess = preprocess
        self.test = test

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'I' (preprocessed image tensor) and 'mos' (float).
        """
        image_name = os.path.join(self.img_dir, self.data.iloc[index, 0])
        image_name = image_name.replace('\\', '/')
        I = self.loader(image_name)
        I = self.preprocess(I)

        mos = self.data.iloc[index, self.mos_col]
        sample = {'I': I, 'mos': float(mos)}

        return sample

    def __len__(self):
        # NOTE(review): the original defined __len__ twice; the duplicate
        # (len(self.data), identical in effect) was removed.
        return len(self.data.index)
|
|
|
|
|
|
|
|
class ImageDataset_ms(Dataset):
    """Multi-scale IQA dataset: one global view plus patches extracted at two
    scales (unfold strides 16 and 32), concatenated along dim 0.
    """

    def __init__(self, csv_file,
                 img_dir,
                 preprocess1,
                 preprocess2,
                 preprocess3,
                 num_patch,
                 set,
                 test,
                 get_loader=get_default_img_loader):
        """
        Args:
            csv_file (string): annotations; .txt (tab-separated, no header),
                .xlsx, or .csv (header row) are supported.
            img_dir (string): Directory of the images.
            preprocess1 (callable): transform for the global view.
            preprocess2 (callable): transform for the stride-16 patch scale.
            preprocess3 (callable): transform for the stride-32 patch scale.
            num_patch (int): total patch budget; (num_patch - 1) // 2 patches
                are drawn per patch scale, on top of the 1 global view.
            set (int): split id kept when the csv has a 'split' column;
                3 keeps every row.
            test (bool): if True, patches are picked deterministically.
            get_loader (callable): factory that returns the image loader.
        """
        if csv_file[-3:] == 'txt':
            self.data = pd.read_csv(csv_file, sep='\t', header=None)
        elif csv_file[-4:] == 'xlsx':
            self.data = pd.read_excel(csv_file, header=0)
        else:
            data = pd.read_csv(csv_file, header=0)
            if ('split' in data.columns) and (set != 3):
                self.data = data[data.split == set]
            else:
                self.data = data
        # The MOS always lives in the second column for every input format.
        self.mos_col = 1
        print('%d csv data successfully loaded!' % self.__len__())
        self.img_dir = img_dir
        self.loader = get_loader()
        self.preprocess1 = preprocess1
        self.preprocess2 = preprocess2
        self.preprocess3 = preprocess3
        self.num_patch = num_patch
        self.test = test

    def _select_patches(self, img, step, count):
        """Extract all 224x224 patches of *img* at stride *step* and keep
        *count* of them (evenly spaced in test mode, random otherwise)."""
        n_channels, kernel_h, kernel_w = 3, 224, 224
        patches = img.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(
            0, 2, 3, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)
        if self.test:
            sel_step = patches.size(0) // count
            sel = torch.arange(count) * sel_step
        else:
            sel = torch.randint(low=0, high=patches.size(0), size=(count,))
        return patches[sel, ...]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'I' (global view + two patch scales stacked along
            dim 0), 'mos' (float) and 'dists' (tensor or 0).
        """
        image_name = os.path.join(self.img_dir, self.data.iloc[index, 0])
        I = self.loader(image_name)

        num_patch_per_scale = (self.num_patch - 1) // 2
        I1 = self.preprocess1(I).unsqueeze(0)
        # BUG FIX(review): the original applied preprocess1 to every scale,
        # leaving preprocess2/preprocess3 stored but unused; the per-scale
        # transforms are now applied as the parameter names indicate.
        I2 = self.preprocess2(I).unsqueeze(0)
        I3 = self.preprocess3(I).unsqueeze(0)

        all_patches = [I1]
        all_patches.append(self._select_patches(I2, step=16, count=num_patch_per_scale))
        all_patches.append(self._select_patches(I3, step=32, count=num_patch_per_scale))

        all_patches = torch.cat(all_patches, 0)
        mos = self.data.iloc[index, self.mos_col]
        if self.data.shape[1] == 23:
            # Every second column after the MOS holds a distortion score.
            distortions = self.data.iloc[index, self.mos_col + 1::2]
            distortions = torch.from_numpy(distortions.to_numpy(dtype=float))
        else:
            distortions = 0

        sample = {'I': all_patches, 'mos': float(mos), 'dists': distortions}

        return sample

    def __len__(self):
        return len(self.data.index)
|
|
|
|
|
|
|
|
|