"""Factory helpers for datasets, DataLoaders and image pre-processing.

Each ``set_*`` function instantiates the matching project dataset class and
wraps it in a :class:`torch.utils.data.DataLoader` that shuffles during
training and iterates sequentially during testing.  The ``_preprocess*``
functions build torchvision transform pipelines (CLIP-style vs. SigLIP-style
normalisation, with or without augmentation).
"""

from ImageDataset import ImageDataset
from torch.utils.data import DataLoader
from ImageDataset2 import (
    ImageDataset2,
    ImageDataset_qonly,
    ImageDataset_oppo,
    ImageDataset_llie,
    ImageDataset_pseudo_label,
    ImageDataset_llie2,
    ImageDataset_llie_general,
    ImageDataset_ms,
    ImageDataset_llie_naflex,
    ImageDataset_sr_naflex,
    ImageDataset_diqa_naflex,
)
from ImageDataset import ImageDataset_SPAQ, ImageDataset_TID, ImageDataset_PIPAL, ImageDataset_ava
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, CenterCrop, RandomCrop, Resize
from torchvision import transforms
import torch
from PIL import Image

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
    BILINEAR = InterpolationMode.BILINEAR
except ImportError:
    # Older torchvision without InterpolationMode: fall back to PIL constants.
    # (A BILINEAR fallback is needed too: AdaptiveResize's default argument is
    # evaluated at class-definition time and would otherwise raise NameError.)
    BICUBIC = Image.BICUBIC
    BILINEAR = Image.BILINEAR


def _make_loader(dataset, bs, num_workers, test, collate_fn=None):
    """Wrap *dataset* in a DataLoader.

    Shuffles only when ``test`` is falsy (training); always pins memory.
    ``collate_fn=None`` selects the DataLoader's default collation.
    """
    return DataLoader(dataset, batch_size=bs, shuffle=not test,
                      pin_memory=True, num_workers=num_workers,
                      collate_fn=collate_fn)


def set_dataset(csv_file, bs, data_set, num_workers, preprocess, num_patch, test):
    """Build a DataLoader over ImageDataset2 (generic IQA dataset)."""
    data = ImageDataset2(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_dataset_oppo(csv_file, bs, data_set, num_workers, preprocess, num_patch, test):
    """Build a DataLoader over ImageDataset_oppo."""
    data = ImageDataset_oppo(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_spaq(csv_file, bs, data_set, num_workers, preprocess, num_patch, test):
    """Build a DataLoader over the SPAQ dataset."""
    data = ImageDataset_SPAQ(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_tid(csv_file, bs, data_set, num_workers, preprocess, num_patch, test):
    """Build a DataLoader over the TID dataset."""
    data = ImageDataset_TID(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_pipal(csv_file, bs, data_set, num_workers, preprocess, num_patch, test):
    """Build a DataLoader over the PIPAL dataset."""
    data = ImageDataset_PIPAL(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_ava(csv_file, bs, data_set, num_workers, preprocess, num_patch, test):
    """Build an evaluation DataLoader over the AVA dataset.

    NOTE(review): ``csv_file``, ``num_patch`` and ``test`` are accepted for
    signature parity with the other ``set_*`` helpers but are ignored: the
    annotation file is hard-coded to ``./ava_test.npy`` and the loader never
    shuffles (test-only usage) — confirm this is intentional with callers.
    """
    data = ImageDataset_ava(
        npy_file='./ava_test.npy',
        img_dir=data_set,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test=True)


def set_dataset_qonly(csv_file, bs, data_set, num_workers, preprocess, num_patch, test, set):
    """Build a DataLoader over ImageDataset_qonly (quality score only)."""
    data = ImageDataset_qonly(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        set=set,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_dataset_llie(csv_file, bs, data_set, spatialFeat, num_workers, preprocess, num_patch, test, set):
    """Build a DataLoader over ImageDataset_llie (low-light enhancement IQA)."""
    data = ImageDataset_llie(
        csv_file=csv_file,
        img_dir=data_set,
        spatialFeat=spatialFeat,
        num_patch=num_patch,
        set=set,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_dataset_llie2(csv_file, bs, data_set, num_workers, preprocess, num_patch, test, set):
    """Build a DataLoader over ImageDataset_llie2."""
    data = ImageDataset_llie2(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        set=set,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_dataset_llie_naflex(csv_file, bs, data_set, num_workers, preprocess, num_patch, test, set):
    """Build a DataLoader over ImageDataset_llie_naflex.

    Uses :func:`custom_collect_fn` because samples carry variable-size
    images that the default collation cannot stack.
    """
    data = ImageDataset_llie_naflex(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        set=set,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test, collate_fn=custom_collect_fn)


def custom_collect_fn(batch):
    """Collate (image, mos, dist) samples.

    Images are left as a tuple (not stacked — they may differ in size);
    MOS values become a 1-D tensor; per-sample dist tensors are stacked.
    """
    vids, moss, dists = zip(*batch)
    return vids, torch.tensor(moss), torch.stack(dists)


def set_dataset_sr_naflex(csv_file, bs, data_set, num_workers, preprocess, num_patch, test, set):
    """Build a DataLoader over ImageDataset_sr_naflex (super-resolution IQA).

    Uses :func:`custom_collect_fn2` for variable-size image batches.
    """
    data = ImageDataset_sr_naflex(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        set=set,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test, collate_fn=custom_collect_fn2)


def custom_collect_fn2(batch):
    """Collate (image, mos) samples: images stay a tuple, MOS becomes a tensor."""
    vids, moss = zip(*batch)
    return vids, torch.tensor(moss)


def set_dataset_diqa_naflex(csv_file, bs, data_set, num_workers, preprocess, num_patch, test, set):
    """Build a DataLoader over ImageDataset_diqa_naflex (document IQA).

    Uses :func:`custom_collect_fn3` for variable-size image batches with
    three score channels.
    """
    data = ImageDataset_diqa_naflex(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        set=set,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test, collate_fn=custom_collect_fn3)


def custom_collect_fn3(batch):
    """Collate (image, ref_image, overall, sharpness, color) samples.

    Both image fields stay tuples; the three score fields become 1-D tensors.
    """
    vids, vids_r, overall_moss, sharp_moss, color_moss = zip(*batch)
    return (vids, vids_r, torch.tensor(overall_moss),
            torch.tensor(sharp_moss), torch.tensor(color_moss))


def set_dataset_general(csv_file, bs, data_set, num_workers, preprocess, test, set):
    """Build a DataLoader over ImageDataset_llie_general (no num_patch)."""
    data = ImageDataset_llie_general(
        csv_file=csv_file,
        img_dir=data_set,
        set=set,
        test=test,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


def set_dataset_pseudo_label(csv_file, bs, data_set, num_workers, preprocess, num_patch, test, set, pseudo_label):
    """Build a DataLoader over ImageDataset_pseudo_label."""
    data = ImageDataset_pseudo_label(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        set=set,
        test=test,
        pseudo_label=pseudo_label,
        preprocess=preprocess)
    return _make_loader(data, bs, num_workers, test)


class AdaptiveResize(object):
    """Resize the input PIL Image to the given size adaptively.

    Downscales only: images whose smaller edge is already below ``size``
    are returned unchanged, so the transform never upscales (unless a
    minimum ``image_size`` is given, which forces an upscale for images
    smaller than that bound).

    Args:
        size (int): target for the smaller edge when downscaling.
        interpolation: interpolation mode, default bilinear.
        image_size (int, optional): minimum allowed edge length; images
            with any edge below it are upscaled to this size.
    """

    def __init__(self, size, interpolation=BILINEAR, image_size=None):
        assert isinstance(size, int)
        self.size = size
        self.interpolation = interpolation
        self.image_size = image_size

    def __call__(self, img):
        """Return *img* rescaled (or unchanged) per the rules above.

        NOTE: PIL's ``Image.size`` is ``(width, height)``; the original
        code named the tuple ``(h, w)``.  All comparisons below are
        symmetric in the two edges, so behaviour is identical either way.
        """
        w, h = img.size
        if self.image_size is not None:
            if w < self.image_size or h < self.image_size:
                # Enforce the minimum edge length by upscaling.
                return transforms.Resize(self.image_size, self.interpolation)(img)
        if w < self.size or h < self.size:
            # Already smaller than the target: never upscale.
            return img
        return transforms.Resize(self.size, self.interpolation)(img)


def _convert_image_to_rgb(image):
    """Force a PIL image into RGB mode (drops alpha, expands grayscale)."""
    return image.convert("RGB")


def _preprocess2():
    """Evaluation pipeline with CLIP normalisation, max smaller edge 768."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(768),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])


def _preprocess22(size=448):
    """Evaluation pipeline without normalisation (raw [0, 1] tensors)."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(size),
        ToTensor(),
    ])


def _preprocess3():
    """Training pipeline with CLIP normalisation and horizontal flip."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(768),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])


def _preprocess33(size=448, crop_size=384):
    """Training pipeline: resize, random crop, flip — no normalisation."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(size),
        transforms.RandomCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
    ])


def _preprocess333(size=448):
    """Training pipeline: resize and flip only — no crop, no normalisation."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(size),
        RandomHorizontalFlip(),
        ToTensor(),
    ])


def _preprocess_siglip():
    """Evaluation pipeline with SigLIP normalisation (mean/std 0.5)."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(768),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def _preprocess_siglip_train():
    """Training pipeline with SigLIP normalisation and horizontal flip."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(768),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def _preprocess_siglip2():
    """Evaluation pipeline with SigLIP normalisation (same as _preprocess_siglip)."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(768),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def _preprocess_siglip_train2():
    """Training pipeline with SigLIP normalisation (same as _preprocess_siglip_train)."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(768),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def convert_models_to_fp32(model):
    """Cast all parameters (and their gradients, if present) to float32 in place."""
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def _preprocess_scale1_train():
    """Training pipeline at fixed 224 px: resize, crop, flip, SigLIP norm."""
    return Compose([
        _convert_image_to_rgb,
        Resize(224),
        RandomCrop(224),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def _preprocess_scale1_test():
    """Evaluation pipeline at fixed 224 px: resize, center crop, SigLIP norm."""
    return Compose([
        _convert_image_to_rgb,
        Resize(224),
        CenterCrop(224),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def _preprocess_scale_test(size):
    """Evaluation pipeline with adaptive resize to *size* and SigLIP norm."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(size),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def _preprocess_scale_train(size):
    """Training pipeline with adaptive resize to *size*, flip, SigLIP norm."""
    return Compose([
        _convert_image_to_rgb,
        AdaptiveResize(size),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def set_dataset_ms(csv_file, bs, data_set, num_workers, preprocess1, preprocess2, preprocess3, num_patch, test, set):
    """Build a DataLoader over ImageDataset_ms (multi-scale, three pipelines)."""
    data = ImageDataset_ms(
        csv_file=csv_file,
        img_dir=data_set,
        num_patch=num_patch,
        set=set,
        test=test,
        preprocess1=preprocess1,
        preprocess2=preprocess2,
        preprocess3=preprocess3)
    return _make_loader(data, bs, num_workers, test)