from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose
import os, sys
from argparse import Namespace
from typing import Union, Tuple, List, Dict, Optional
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import datasets


def calc_bin_center(
bins: List[Tuple[float, float]],
count_stats: Dict[int, int],
) -> Tuple[List[float], List[int]]:
"""
Calculate the representative value for each bin based on the count statistics.
`bins` may look like: [(0, 0), (1, 1), (2, 3), (4, 6), (7, float('inf'))]
`count_stats` may look like: {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}
In this example, for bin (2, 3), we have 30 samples of 2 and 40 samples of 3 that fall into this bin.
    The representative value for this bin is (30 * 2 + 40 * 3) / (30 + 40) = 180 / 70 ≈ 2.57.
    Returns two lists, each of the same length as `bins`: the representative (weighted-mean) value of each bin and the number of samples that fall into it.
"""
bin_counts = [0] * len(bins)
bin_sums = [0] * len(bins)
for k, v in count_stats.items():
for i, (start, end) in enumerate(bins):
if start <= int(k) <= end:
bin_counts[i] += int(v)
bin_sums[i] += int(v) * int(k)
break
    assert all(c > 0 for c in bin_counts), f"Expected all bin_counts to be greater than 0, got {bin_counts}. Consider re-designing the bins {bins}."
bin_centers = [s / c for s, c in zip(bin_sums, bin_counts)]
return bin_centers, bin_counts
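
# Quick reference example for calc_bin_center (mirrors the docstring above; documentation
# only, it is not executed anywhere in the pipeline):
#   bins  = [(0, 0), (1, 1), (2, 3), (4, 6), (7, float("inf"))]
#   stats = {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}
#   centers, counts = calc_bin_center(bins, stats)
#   counts  -> [10, 20, 70, 180, 270]
#   centers -> [0.0, 1.0, 2.57, 5.11, 8.07]  (rounded to 2 decimals)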


def get_dataloader(args: Namespace, split: str = "train") -> Union[Tuple[DataLoader, Optional[DistributedSampler]], DataLoader]:
    """
    Build the DataLoader for the given split.
    For `split == "train"`, a `(DataLoader, DistributedSampler or None)` tuple is returned;
    for `split == "val"`, only the DataLoader is returned.
    """
    ddp = args.nprocs > 1
if split == "train": # train, strong augmentation
transforms = [
datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.aug_min_scale, args.aug_max_scale)),
datasets.RandomHorizontalFlip(),
]
if args.aug_brightness > 0 or args.aug_contrast > 0 or args.aug_saturation > 0 or args.aug_hue > 0:
transforms.append(datasets.ColorJitter(
brightness=args.aug_brightness, contrast=args.aug_contrast, saturation=args.aug_saturation, hue=args.aug_hue
))
if args.aug_blur_prob > 0 and args.aug_kernel_size > 0:
transforms.append(datasets.RandomApply([
datasets.GaussianBlur(kernel_size=args.aug_kernel_size),
], p=args.aug_blur_prob))
if args.aug_saltiness > 0 or args.aug_spiciness > 0:
transforms.append(datasets.PepperSaltNoise(
saltiness=args.aug_saltiness, spiciness=args.aug_spiciness,
))
transforms = Compose(transforms)
    elif args.sliding_window and args.resize_to_multiple:  # val, resize for sliding-window evaluation
        transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
    else:  # val, no extra transforms
        transforms = None
dataset_class = datasets.InMemoryCrowd if args.in_memory_dataset else datasets.Crowd
    prefetch_factor = None if args.num_workers == 0 else 3
    persistent_workers = args.num_workers > 0
dataset = dataset_class(
dataset=args.dataset,
split=split,
transforms=transforms,
sigma=None,
return_filename=False,
num_crops=args.num_crops if split == "train" else 1,
)
if ddp and split == "train": # data_loader for training in DDP
        sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=True, seed=args.seed)  # the seed must be identical across ranks
data_loader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=datasets.collate_fn,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
return data_loader, sampler
elif (not ddp) and split == "train": # data_loader for training
data_loader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=datasets.collate_fn,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
return data_loader, None
elif ddp and split == "val":
sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=False)
data_loader = DataLoader(
dataset,
batch_size=1, # Use batch size 1 for evaluation
sampler=sampler,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=datasets.collate_fn,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
return data_loader
else: # (not ddp) and split == "val"
data_loader = DataLoader(
dataset,
batch_size=1, # Use batch size 1 for evaluation
shuffle=False,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=datasets.collate_fn,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)
return data_loader
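

# Minimal usage sketch (assumption: this is normally driven by a training script that
# defines these CLI options elsewhere in the repo). The attribute names below are exactly
# the ones get_dataloader reads; the concrete values are illustrative placeholders, not
# the repository's defaults.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        dataset="sha",  # placeholder dataset name
        nprocs=1, local_rank=0, seed=42,
        batch_size=8, num_workers=4, num_crops=1,
        input_size=512, aug_min_scale=0.5, aug_max_scale=1.5,
        aug_brightness=0.2, aug_contrast=0.2, aug_saturation=0.2, aug_hue=0.0,
        aug_blur_prob=0.5, aug_kernel_size=5,
        aug_saltiness=0.0, aug_spiciness=0.0,
        sliding_window=False, resize_to_multiple=False, window_size=224, stride=224,
        in_memory_dataset=False,
    )
    train_loader, train_sampler = get_dataloader(args, split="train")
    val_loader = get_dataloader(args, split="val")
    print(f"train batches: {len(train_loader)}, val batches: {len(val_loader)}")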