miqa
File size: 7,295 Bytes
3090535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import numpy as np
import torch
import torch.distributed as dist

from data.data_transforms import build_transform
from data.dataset_cls import MIQACLSDataset
from data.dataset_det import MIQADETDataset
from data.dataset_ins import MIQAINSDataset
from data.samplers import PatchDistributedSampler, SubsetRandomSampler


# Dataset names accepted by ``build_dataset``.
_SUPPORTED_MIQA_DATASETS = ("miqa_cls", "miqa_det", "miqa_ins")


def _build_miqa_split(config, is_train):
    """Build one split (train or validation) of the dataset named by ``config.dataset``.

    Args:
        config: Configuration object carrying the dataset name, dataset root
            paths, split files, ``patch_num``, ``metric_type`` and (for
            ``miqa_cls``) ``return_all_metrics``.
        is_train: ``True`` for the training split (``config.train_split_file``
            and training transforms), ``False`` for the validation split
            (``config.val_split_file`` and non-augmenting transforms).

    Returns:
        An ``MIQACLSDataset``, ``MIQADETDataset`` or ``MIQAINSDataset`` instance.

    Raises:
        NotImplementedError: If ``config.dataset`` is not one of
            ``_SUPPORTED_MIQA_DATASETS``.
    """
    name = config.dataset
    if name not in _SUPPORTED_MIQA_DATASETS:
        raise NotImplementedError(
            f"Unsupported dataset {name!r}; supported MIQA datasets are "
            f"{', '.join(_SUPPORTED_MIQA_DATASETS)}."
        )

    # Keyword arguments shared by every MIQA dataset class.
    common_kwargs = dict(
        split_file=config.train_split_file if is_train else config.val_split_file,
        patch_num=config.patch_num,  # Number of image patches per sample
        transforms=build_transform(is_train=is_train, config=config),
        metric_type=config.metric_type,  # Evaluation metric type
        is_train=is_train,
    )

    if name == "miqa_cls":
        # Classification-oriented MIQA task.
        return MIQACLSDataset(
            root=config.path_miqa_cls,
            return_all_metrics=config.return_all_metrics,
            **common_kwargs,
        )

    if name == "miqa_det":
        # Detection-oriented MIQA task; this task never returns all metrics.
        return MIQADETDataset(
            root=config.path_miqa_det,
            return_all_metrics=False,
            **common_kwargs,
        )

    # Remaining case: name == "miqa_ins" (instance-segmentation-oriented task).
    # NOTE(review): the training split reads instance labels from
    # ``config.path_label_ins`` while the validation split reads them from
    # ``config.path_miqa_ins``. This preserves the original behavior; confirm
    # which config key is intended and unify them if they should match.
    ins_root = config.path_label_ins if is_train else config.path_miqa_ins
    return MIQAINSDataset(
        det_root=config.path_miqa_det,  # Detection data root directory
        ins_root=ins_root,  # Instance segmentation label root directory
        return_all_metrics=False,
        **common_kwargs,
    )


def build_dataset(config):
    """Build the MIQA dataset(s) selected by ``config.dataset``.

    Supported values of ``config.dataset`` are ``"miqa_cls"``,
    ``"miqa_det"`` and ``"miqa_ins"``.

    Args:
        config: Configuration object containing the dataset type, paths,
            split files and the ``eval_only`` flag.

    Returns:
        Only the validation/test dataset when ``config.eval_only`` is truthy;
        otherwise the tuple ``(train_dataset, test_dataset)``.

    Raises:
        NotImplementedError: If ``config.dataset`` is not a supported MIQA task.
    """
    # Evaluation-only runs never need the training split.
    if config.eval_only:
        return _build_miqa_split(config, is_train=False)

    # Training split is built first, then the validation split.
    train_dataset = _build_miqa_split(config, is_train=True)
    test_dataset = _build_miqa_split(config, is_train=False)
    return train_dataset, test_dataset


def build_loader(config):
    """Build training and validation datasets and data loaders for distributed training.

    Requires ``torch.distributed`` to be initialized: the world size and this
    process's global rank are queried via ``dist``.

    Args:
        config: Configuration object (exposing ``defrost()``/``freeze()``)
            containing all hyperparameters and path information.

    Returns:
        Tuple ``(dataset_train, dataset_val, data_loader_train, data_loader_val)``.

    Raises:
        ValueError: If ``config.eval_only`` is truthy. In that mode
            ``build_dataset`` returns only the validation dataset, so no
            training loader can be built; call ``build_dataset`` directly.
    """
    if config.eval_only:
        # build_dataset returns a single dataset in eval-only mode; the
        # (train, val) unpacking below would otherwise fail obscurely.
        raise ValueError(
            "build_loader requires config.eval_only to be False; "
            "call build_dataset(config) directly for evaluation-only runs."
        )

    # Allow config modification while building datasets, and always re-freeze
    # it (even if dataset construction fails) so it is never left mutable.
    config.defrost()
    try:
        dataset_train, dataset_val = build_dataset(config=config)
    finally:
        config.freeze()

    # Distributed context: total number of processes and this process's rank.
    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()

    print(
        f"local rank {config.LOCAL_RANK} / global rank {global_rank} successfully build train dataset"
    )
    print(
        f"local rank {config.LOCAL_RANK} / global rank {global_rank} successfully build val dataset"
    )

    # Training sampler.
    if config.ZIP_MODE and config.CACHE_MODE == "part":
        # ZIP mode with "part" cache: each process takes a fixed strided subset
        # of indices (rank, rank + world_size, ...) and samples randomly from it.
        indices = np.arange(global_rank, len(dataset_train), num_tasks)
        sampler_train = SubsetRandomSampler(indices)
    else:
        # Otherwise shard (with shuffling) across processes.
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

    # Validation sampler.
    if config.test.SEQUENTIAL:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        # Project-specific sampler defined in data.samplers.
        sampler_val = PatchDistributedSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=config.batch_size_train,
        num_workers=config.num_workers_train,
        pin_memory=config.pin_memory_train,
        drop_last=True,  # Drop the last incomplete training batch
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=config.batch_size_val,
        shuffle=False,
        num_workers=config.num_workers_val,
        pin_memory=config.pin_memory_val,
        drop_last=False,  # Keep the last incomplete validation batch
    )

    return dataset_train, dataset_val, data_loader_train, data_loader_val