from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose
import os
import sys
from argparse import Namespace
from typing import Union, Tuple, List, Dict

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import datasets


def calc_bin_center(
    bins: List[Tuple[float, float]],
    count_stats: Dict[int, int],
) -> Tuple[List[float], List[int]]:
    """
    Calculate the representative value for each bin based on the count statistics.

    `bins` may look like: [(0, 0), (1, 1), (2, 3), (4, 6), (7, float('inf'))]
    `count_stats` may look like: {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}
    In this example, for bin (2, 3), we have 30 samples of 2 and 40 samples of 3 that fall into this bin.
    The representative value for this bin is (30 * 2 + 40 * 3) / (30 + 40) = 2.6.

    The returned list will have the same length as `bins`, and each element is the representative value for the corresponding bin.
    """
    bin_counts = [0] * len(bins)
    bin_sums = [0] * len(bins)
    for k, v in count_stats.items():
        for i, (start, end) in enumerate(bins):
            if start <= int(k) <= end:
                bin_counts[i] += int(v)
                bin_sums[i] += int(v) * int(k)
                break
    assert all(c > 0 for c in bin_counts), f"Expected every bin to contain at least one sample, got bin_counts={bin_counts}. Consider re-designing the bins {bins}."
    bin_centers = [s / c for s, c in zip(bin_sums, bin_counts)]
    return bin_centers, bin_counts
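
# Illustrative usage of `calc_bin_center`, mirroring the docstring example (kept as a
# comment so importing this module has no side effects):
#
#     bins = [(0, 0), (1, 1), (2, 3), (4, 6), (7, float("inf"))]
#     stats = {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}
#     centers, counts = calc_bin_center(bins, stats)
#     # centers ≈ [0.0, 1.0, 2.571, 5.111, 8.074], counts == [10, 20, 70, 180, 270]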


def get_dataloader(args: Namespace, split: str = "train") -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    """
    Build a dataloader for the given split. For split == "train" this returns
    (data_loader, sampler), where sampler is a DistributedSampler under DDP and
    None otherwise; for split == "val" only the data_loader is returned.
    """
    ddp = args.nprocs > 1
    if split == "train":  # train, strong augmentation
        transforms = [
            datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.aug_min_scale, args.aug_max_scale)),
            datasets.RandomHorizontalFlip(),
        ]
        if args.aug_brightness > 0 or args.aug_contrast > 0 or args.aug_saturation > 0 or args.aug_hue > 0:
            transforms.append(datasets.ColorJitter(
                brightness=args.aug_brightness, contrast=args.aug_contrast, saturation=args.aug_saturation, hue=args.aug_hue
            ))
        if args.aug_blur_prob > 0 and args.aug_kernel_size > 0:
            transforms.append(datasets.RandomApply([
                datasets.GaussianBlur(kernel_size=args.aug_kernel_size),
            ], p=args.aug_blur_prob))
        if args.aug_saltiness > 0 or args.aug_spiciness > 0:
            transforms.append(datasets.PepperSaltNoise(
                saltiness=args.aug_saltiness, spiciness=args.aug_spiciness,
            ))
        transforms = Compose(transforms)

    elif args.sliding_window and args.resize_to_multiple:  # val, resize images so they tile evenly for sliding-window evaluation
        transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)

    else:
        transforms = None

    dataset_class = datasets.InMemoryCrowd if args.in_memory_dataset else datasets.Crowd
    # Worker-dependent options: both must be disabled when num_workers == 0.
    prefetch_factor = None if args.num_workers == 0 else 3
    persistent_workers = args.num_workers > 0

    dataset = dataset_class(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if split == "train" else 1,
    )

    if ddp and split == "train":  # data_loader for training in DDP
        # DistributedSampler requires the same seed on every rank so that all processes
        # agree on one shuffled permutation before it is partitioned across replicas.
        sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=True, seed=args.seed)
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )
        return data_loader, sampler

    elif (not ddp) and split == "train":  # data_loader for training
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )
        return data_loader, None

    elif ddp and split == "val":
        sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=False)
        data_loader = DataLoader(
            dataset,
            batch_size=1,  # Use batch size 1 for evaluation
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )
        return data_loader
    
    else:  # (not ddp) and split == "val"
        data_loader = DataLoader(
            dataset,
            batch_size=1,  # Use batch size 1 for evaluation
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )
        return data_loader
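

if __name__ == "__main__":
    # Minimal smoke-test sketch, not used by the training/eval scripts. It builds a
    # Namespace with the fields `get_dataloader` reads; the concrete values below
    # (dataset key, crop size, augmentation strengths, ...) are illustrative
    # assumptions and should be replaced with real experiment settings.
    example_args = Namespace(
        dataset="sha",               # assumed dataset key understood by datasets.Crowd
        nprocs=1, local_rank=0, seed=42,
        input_size=256, aug_min_scale=0.5, aug_max_scale=2.0,
        aug_brightness=0.2, aug_contrast=0.2, aug_saturation=0.2, aug_hue=0.0,
        aug_blur_prob=0.5, aug_kernel_size=5,
        aug_saltiness=0.0, aug_spiciness=0.0,
        sliding_window=False, resize_to_multiple=False, window_size=256, stride=256,
        in_memory_dataset=False, num_crops=1,
        batch_size=8, num_workers=0,
    )
    train_loader, train_sampler = get_dataloader(example_args, split="train")
    print(f"train batches: {len(train_loader)}, sampler: {train_sampler}")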