File size: 7,295 Bytes
import numpy as np
import torch
import torch.distributed as dist
from data.data_transforms import build_transform
from data.dataset_cls import MIQACLSDataset
from data.dataset_det import MIQADETDataset
from data.dataset_ins import MIQAINSDataset
from data.samplers import PatchDistributedSampler, SubsetRandomSampler
def build_dataset(config):
    """
    Build the MIQA dataset(s) selected by ``config.dataset``.

    Supported values are ``"miqa_cls"``, ``"miqa_det"`` and ``"miqa_ins"``.

    Args:
        config: Configuration object providing the dataset name, root paths,
            split files, patch count, metric type and the ``eval_only`` flag.
    Returns:
        test_dataset alone when ``config.eval_only`` is truthy, otherwise the
        tuple ``(train_dataset, test_dataset)``.
    Raises:
        NotImplementedError: If ``config.dataset`` is not a supported name.
    """
    task = config.dataset
    if task not in ("miqa_cls", "miqa_det", "miqa_ins"):
        raise NotImplementedError("We only support common IQA dataset now.")

    def _make(training):
        # Arguments that every task passes in the same way; only the split
        # file and the transform pipeline depend on the train/val mode.
        shared = dict(
            split_file=config.train_split_file if training else config.val_split_file,
            patch_num=config.patch_num,
            transforms=build_transform(is_train=training, config=config),
            metric_type=config.metric_type,
            is_train=training,
        )
        if task == "miqa_cls":
            # Only the classification task forwards the configured flag.
            return MIQACLSDataset(
                root=config.path_miqa_cls,
                return_all_metrics=config.return_all_metrics,
                **shared,
            )
        if task == "miqa_det":
            return MIQADETDataset(
                root=config.path_miqa_det,
                return_all_metrics=False,
                **shared,
            )
        # NOTE(review): the training split reads config.path_label_ins while
        # the validation split reads config.path_miqa_ins. Kept exactly as
        # found -- confirm whether both names are intended.
        ins_root = config.path_label_ins if training else config.path_miqa_ins
        return MIQAINSDataset(
            det_root=config.path_miqa_det,
            ins_root=ins_root,
            return_all_metrics=False,
            **shared,
        )

    # The training dataset (and its transform) is never built when evaluating.
    if config.eval_only:
        return _make(False)
    return _make(True), _make(False)
def build_loader(config):
    """
    Build the datasets and their data loaders for a distributed run.

    Requires an initialised ``torch.distributed`` process group, because the
    rank and world size are queried from it.

    Args:
        config: Configuration object. Besides everything ``build_dataset``
            reads, this function uses ``eval_only``, ``LOCAL_RANK``,
            ``ZIP_MODE``, ``CACHE_MODE``, ``test.SEQUENTIAL`` and the
            ``batch_size_*`` / ``num_workers_*`` / ``pin_memory_*`` fields.
    Returns:
        dataset_train: Training dataset, or None when ``config.eval_only``.
        dataset_val: Validation dataset.
        data_loader_train: Training data loader, or None when
            ``config.eval_only``.
        data_loader_val: Validation data loader.
    """
    # The config is unlocked only while the datasets are built; re-freeze it
    # even if dataset construction raises so it is never left writable.
    config.defrost()
    try:
        built = build_dataset(config=config)
    finally:
        config.freeze()

    # build_dataset returns a single dataset in eval-only mode and a
    # (train, val) pair otherwise; unpacking unconditionally would fail.
    if config.eval_only:
        dataset_train, dataset_val = None, built
    else:
        dataset_train, dataset_val = built

    # Query the process group once and reuse the values below.
    global_rank = dist.get_rank()
    num_tasks = dist.get_world_size()

    if dataset_train is not None:
        print(
            f"local rank {config.LOCAL_RANK} / global rank {global_rank} successfully build train dataset"
        )
    print(
        f"local rank {config.LOCAL_RANK} / global rank {global_rank} successfully build val dataset"
    )

    data_loader_train = None
    if dataset_train is not None:
        if config.ZIP_MODE and config.CACHE_MODE == "part":
            # Each process takes every num_tasks-th index starting at its own
            # global rank, and samples randomly within that subset.
            indices = np.arange(global_rank, len(dataset_train), num_tasks)
            sampler_train = SubsetRandomSampler(indices)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            sampler=sampler_train,
            batch_size=config.batch_size_train,
            num_workers=config.num_workers_train,
            pin_memory=config.pin_memory_train,
            drop_last=True,  # Drop the last incomplete batch
        )

    if config.test.SEQUENTIAL:
        # Every process iterates the full validation set in order.
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        # Project-specific sampler; presumably shards validation patches
        # across processes -- confirm in data.samplers.
        sampler_val = PatchDistributedSampler(dataset_val)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=config.batch_size_val,
        shuffle=False,
        num_workers=config.num_workers_val,
        pin_memory=config.pin_memory_val,
        drop_last=False,  # Keep the last incomplete batch
    )
    return dataset_train, dataset_val, data_loader_train, data_loader_val
|