| import numpy as np
|
| import torch
|
| import torch.distributed as dist
|
|
|
| from data.data_transforms import build_transform
|
| from data.dataset_cls import MIQACLSDataset
|
| from data.dataset_det import MIQADETDataset
|
| from data.dataset_ins import MIQAINSDataset
|
| from data.samplers import PatchDistributedSampler, SubsetRandomSampler
|
|
|
|
|
| def build_dataset(config):
|
| """
|
| Build corresponding MIQA dataset based on configuration
|
|
|
| Args:
|
| config: Configuration object containing dataset type, paths, and other information
|
|
|
| Returns:
|
| train_dataset: Training dataset
|
| test_dataset: Test dataset
|
| """
|
|
|
|
|
| if config.dataset == "miqa_cls":
|
| if not config.eval_only:
|
| train_dataset = MIQACLSDataset(
|
| root=config.path_miqa_cls,
|
| split_file=config.train_split_file,
|
| patch_num=config.patch_num,
|
| transforms=build_transform(is_train=True, config=config),
|
| metric_type=config.metric_type,
|
| return_all_metrics=config.return_all_metrics,
|
| is_train=True
|
| )
|
|
|
| test_dataset = MIQACLSDataset(
|
| root=config.path_miqa_cls,
|
| split_file=config.val_split_file,
|
| patch_num=config.patch_num,
|
| transforms=build_transform(is_train=False, config=config),
|
| metric_type=config.metric_type,
|
| return_all_metrics=config.return_all_metrics,
|
| is_train=False
|
| )
|
|
|
|
|
| elif config.dataset == "miqa_det":
|
| if not config.eval_only:
|
| train_dataset = MIQADETDataset(
|
| root=config.path_miqa_det,
|
| split_file=config.train_split_file,
|
| patch_num=config.patch_num,
|
| transforms=build_transform(is_train=True, config=config),
|
| metric_type=config.metric_type,
|
| return_all_metrics=False,
|
| is_train=True
|
| )
|
|
|
| test_dataset = MIQADETDataset(
|
| root=config.path_miqa_det,
|
| split_file=config.val_split_file,
|
| patch_num=config.patch_num,
|
| transforms=build_transform(is_train=False, config=config),
|
| metric_type=config.metric_type,
|
| return_all_metrics=False,
|
| is_train=False
|
| )
|
|
|
|
|
| elif config.dataset == "miqa_ins":
|
| if not config.eval_only:
|
| train_dataset = MIQAINSDataset(
|
| det_root=config.path_miqa_det,
|
| ins_root=config.path_label_ins,
|
| split_file=config.train_split_file,
|
| patch_num=config.patch_num,
|
| transforms=build_transform(is_train=True, config=config),
|
| metric_type=config.metric_type,
|
| return_all_metrics=False,
|
| is_train=True
|
| )
|
|
|
| test_dataset = MIQAINSDataset(
|
| det_root=config.path_miqa_det,
|
| ins_root=config.path_miqa_ins,
|
| split_file=config.val_split_file,
|
| patch_num=config.patch_num,
|
| transforms=build_transform(is_train=False, config=config),
|
| metric_type=config.metric_type,
|
| return_all_metrics=False,
|
| is_train=False
|
| )
|
|
|
| else:
|
| raise NotImplementedError("We only support common IQA dataset now.")
|
|
|
|
|
| if config.eval_only:
|
| return test_dataset
|
| else:
|
| return train_dataset, test_dataset
|
|
|
|
|
| def build_loader(config):
|
| """
|
| Build training and validation data loaders for distributed training
|
|
|
| Args:
|
| config: Configuration object containing all hyperparameters and path information
|
|
|
| Returns:
|
| dataset_train: Training dataset
|
| dataset_val: Validation dataset
|
| data_loader_train: Training data loader
|
| data_loader_val: Validation data loader
|
| """
|
|
|
| config.defrost()
|
|
|
|
|
| dataset_train, dataset_val = build_dataset(config=config)
|
|
|
|
|
| config.freeze()
|
|
|
|
|
| print(
|
| f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset"
|
| )
|
| print(
|
| f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset"
|
| )
|
|
|
|
|
| num_tasks = dist.get_world_size()
|
| global_rank = dist.get_rank()
|
|
|
|
|
| if config.ZIP_MODE and config.CACHE_MODE == "part":
|
|
|
|
|
| indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
|
| sampler_train = SubsetRandomSampler(indices)
|
| else:
|
|
|
| sampler_train = torch.utils.data.DistributedSampler(
|
| dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
| )
|
|
|
|
|
| if config.test.SEQUENTIAL:
|
|
|
| sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
| else:
|
|
|
| sampler_val = PatchDistributedSampler(dataset_val)
|
|
|
|
|
| data_loader_train = torch.utils.data.DataLoader(
|
| dataset_train,
|
| sampler=sampler_train,
|
| batch_size=config.batch_size_train,
|
| num_workers=config.num_workers_train,
|
| pin_memory=config.pin_memory_train,
|
| drop_last=True,
|
| )
|
|
|
|
|
| data_loader_val = torch.utils.data.DataLoader(
|
| dataset_val,
|
| sampler=sampler_val,
|
| batch_size=config.batch_size_val,
|
| shuffle=False,
|
| num_workers=config.num_workers_val,
|
| pin_memory=config.pin_memory_val,
|
| drop_last=False,
|
| )
|
|
|
| return dataset_train, dataset_val, data_loader_train, data_loader_val
|
|
|