# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2025 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *
#  * Loader Factory, Fast Collate, CUDA Prefetcher
#  * Prefetcher and Fast Collate inspired by NVIDIA APEX example at
#  * https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
#  *  Copyright 2019, Ross Wightman
#  *--------------------------------------------------------------------------------------------*/

import os
from functools import partial

import torch
from PIL import Image
from timm.data import FastCollateMixup
from timm.data.dataset import IterableImageDataset
from timm.data.distributed_sampler import (OrderedDistributedSampler,
                                           RepeatAugSampler)
from timm.data.loader import (MultiEpochsDataLoader, PrefetchLoader,
                              _worker_init, fast_collate)
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

from image_classification.pt.src.datasets.augmentations.augs import (
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

class PredictionDataset(Dataset):
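    """Flat-folder image dataset for inference.

    Collects every .png/.jpg/.jpeg file in `folder_path` and returns
    (image_tensor, image_path) pairs. The default transform resizes to
    224x224 and normalizes with ImageNet statistics.
    """
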
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        # sort for a deterministic ordering across runs and filesystems
        self.image_paths = sorted(
            os.path.join(folder_path, f)
            for f in os.listdir(folder_path)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        return img, img_path

    def __len__(self):
        return len(self.image_paths)
    
def create_loader(
    dataset,
    input_size,
    batch_size,
    is_training=False,
    use_prefetcher=True,
    no_aug=False,
    re_prob=0.,
    re_mode='const',
    re_count=1,
    re_num_splits=0,
    num_aug_repeats=0,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    num_workers=1,
    distributed=False,
    collate_fn=None,
    pin_memory=False,
    fp16=False,  # deprecated, use img_dtype
    img_dtype=torch.float32,
    device=torch.device('cuda'),
    use_multi_epochs_loader=False,
    persistent_workers=True,
    worker_seeding='all',
):
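    """Build a DataLoader (optionally distributed / prefetching) for a dataset.

    Mirrors timm's loader factory: picks a distributed sampler when needed,
    uses fast_collate when the CUDA prefetcher is enabled, and wraps the
    resulting loader in PrefetchLoader so normalization and random erasing
    run on the GPU.
    """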
    if isinstance(input_size, int):
        input_size = (3, input_size, input_size)

    if isinstance(dataset, IterableImageDataset):
        # give Iterable datasets early knowledge of num_workers so that sample estimates
        # are correct before worker processes are launched
        dataset.set_loader_cfg(num_workers=num_workers)

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            if num_aug_repeats:
                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
            else:
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)
    else:
        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = {
        'batch_size': batch_size,
        'shuffle': not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        'num_workers': num_workers,
        'sampler': sampler,
        'collate_fn': collate_fn,
        'pin_memory': pin_memory,
        'drop_last': is_training,
        'worker_init_fn': partial(_worker_init, worker_seeding=worker_seeding),
        'persistent_workers': persistent_workers
    }
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in PyTorch 1.7+
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
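        # apply random erasing only on the training path with augmentation enabled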
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(  # pylint: disable=unexpected-keyword-arg
            loader,
            mean=mean,
            std=std,
            channels=input_size[0],
            device=device,
            fp16=fp16,  # deprecated, use img_dtype
            img_dtype=img_dtype,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )
    return loader


def get_dataloader(
    dataset, batch_size=32, num_workers=4, fp16=False, distributed=False, shuffle=False,
    collate_fn=None, device="cuda"
):
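    """Simple DataLoader factory with optional fp16 collation and distributed sampling.

    When `fp16` is True, float32 tensors are cast to float16 immediately after
    collation. When `distributed` is True, a DistributedSampler is used and
    shuffling is delegated to it.
    """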
    if collate_fn is None:
        collate_fn = default_collate

    def half_precision(batch):
        # collate first, then cast float32 tensors to float16
        batch = collate_fn(batch)
        return [x.half() if isinstance(x, torch.Tensor) and x.dtype == torch.float32 else x
                for x in batch]

    sampler = DistributedSampler(dataset) if distributed else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # shuffle and sampler are mutually exclusive; the sampler handles shuffling
        shuffle=shuffle and sampler is None,
        num_workers=num_workers,
        collate_fn=half_precision if fp16 else collate_fn,
        sampler=sampler,
    )
    return dataloader


def prepare_kwargs_for_dataloader(cfg):
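    """Translate a hierarchical Hydra config into the flat keyword arguments
    expected by the data loader factory."""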
    #####################################
    # Flatten the hierarchical Hydra config, keeping only the leaf keys
    #####################################
    from types import SimpleNamespace

    from common.utils import flatten_config
    # Special-case the scale variable: the config contains two variables with the
    # same name, and flattening the config would collapse them into one.
    scale = None
    if "data_augmentation" in cfg and "scale" in cfg.data_augmentation:
        scale = cfg.data_augmentation.scale
    if not scale:
        scale = [0.08, 1.0]
    # convert the hierarchical Hydra config into a flat config
    args_raw = flatten_config(cfg)
    # convert the dict-based config into an argparse-style namespace
    if isinstance(args_raw, dict):
        args = SimpleNamespace(**args_raw)
    else:
        args = args_raw
    #####################################
    args.prefetcher = not getattr(args, 'no_prefetcher', False)
    if not hasattr(args, "batch_size") or args.batch_size is None:
        args.batch_size = 128
    if not hasattr(args, "validation_batch_size") or args.validation_batch_size is None:
        args.validation_batch_size = args.batch_size
    num_aug_splits = 0
    collate_fn = None
    args.mixup = getattr(args, "mixup", 0)
    args.cutmix = getattr(args, "cutmix", 0)
    args.cutmix_minmax = getattr(args, "cutmix_minmax", None)

    mixup_active = (args.mixup > 0) or (args.cutmix > 0) or (args.cutmix_minmax is not None)
    # with the current config, the following branch is not taken
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=getattr(args, "mixup_prob", 1.0),
            switch_prob=getattr(args, "mixup_switch_prob", 0.5),
            mode=getattr(args, "mixup_mode", 'batch'),
            label_smoothing=getattr(args, "smoothing", 0.1),
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
    
    device = args.device  # already set in the main script
    
    if getattr(args, 'aug_splits', 0) > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits
    extra_loader_kwargs = {}
   
    if getattr(args, 'dataset_name', '') in ['imagenet', 'custom']:
        extra_loader_kwargs = {
            'train_split': getattr(args, 'train_split', 'train'),
            'val_split': getattr(args, 'val_split', 'val'),
            'class_map': getattr(args, 'class_map', None),
            'seed': getattr(args, 'seed', None),
            'repeats': getattr(args, 'epoch_repeats', None),
        }
    # create data loaders w/ augmentation pipeline
    test_interpolation = 'bicubic'
    train_interpolation = getattr(args, 'train_interpolation', "random")
    args.no_aug = getattr(args, 'no_aug', False)
    if args.no_aug or not train_interpolation:
        train_interpolation = test_interpolation
    
    args.mean = getattr(args, 'mean', [0.485, 0.456, 0.406])
    args.std = getattr(args, 'std', [0.229, 0.224, 0.225])
    batch_size = getattr(args, 'batch_size', 128)
    validation_batch_size = getattr(args, 'validation_batch_size', batch_size)
    loader_kwargs = dict(
        img_size=tuple(args.input_shape),
        batch_size=batch_size,
        test_batch_size=validation_batch_size,
        download=getattr(args, 'data_download', False),
        distributed=getattr(args, 'distributed', False),
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=getattr(args, 'reprob', 0),
        re_mode=getattr(args, 'remode', 'pixel'),
        re_count=getattr(args, 'recount', 1),
        re_split=getattr(args, 'resplit', False),
        scale=scale,
        ratio=getattr(args, 'ratio', [0.75, 1.33]),
        hflip=getattr(args, 'hflip', 0.5),
        vflip=getattr(args, 'vflip', 0.0),
        color_jitter=getattr(args, 'color_jitter', 0.4),
        auto_augment=getattr(args, 'aa', None),
        num_aug_repeats=getattr(args, 'aug_repeats', 0),
        num_aug_splits=num_aug_splits,
        train_interpolation=train_interpolation,
        test_interpolation=test_interpolation,
        mean=args.mean,  # taken from the config directly
        std=args.std,  # taken from the config directly
        num_workers=getattr(args, 'workers', 4),
        collate_fn=collate_fn,
        pin_memory=getattr(args, 'pin_mem', False),
        device=device,  # TODO: multi-GPU support
        use_multi_epochs_loader=getattr(args, 'use_multi_epochs_loader', False),
        # timm's _worker_init expects 'all', 'part', or a callable, so default to 'all'
        worker_seeding=getattr(args, 'worker_seeding', 'all'),
        test_path=getattr(args, 'test_path', None),
        prediction_path=getattr(args, 'prediction_path', None),
        quantization_path=getattr(args, 'quantization_path', None),
        quantization_split=getattr(args, 'quantization_split', 0.1),
        num_classes=args.num_classes,  # required downstream
        **extra_loader_kwargs)
    return loader_kwargs

def get_from_config(cfg, key_path, default=None):
    """
    Safely get a nested value from cfg using dot-separated keys.
    Works for OmegaConf DictConfig, dict, or objects with attributes.
    
    Example:
        get_from_config(cfg, "a.b.c", default=42)
    """
    keys = key_path.split(".")
    current = cfg

    for key in keys:
        if isinstance(current, dict):
            if key not in current:
                return default
            current = current[key]
        elif hasattr(current, key):
            current = getattr(current, key)
        else:
            return default
    return current
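

if __name__ == "__main__":
    # Minimal usage sketch: run a folder of images through PredictionDataset
    # and the simple loader factory above. The path below is a hypothetical
    # placeholder, not part of the original module.
    demo_dir = "./images"  # hypothetical path; replace with a real folder
    if os.path.isdir(demo_dir):
        dataset = PredictionDataset(demo_dir)
        loader = get_dataloader(dataset, batch_size=8, num_workers=0)
        for images, paths in loader:
            print(images.shape, paths[:2])
            break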