from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose

import os
import sys
from argparse import Namespace
from typing import Union, Tuple, List, Dict

# Make the repository root importable so that `datasets` resolves to the
# project-local package rather than an installed distribution.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import datasets


def calc_bin_center(
    bins: List[Tuple[float, float]],
    count_stats: Dict[int, int],
) -> Tuple[List[float], List[int]]:
    """
    Calculate the representative value for each bin based on the count statistics.

    `bins` may look like: [(0, 0), (1, 1), (2, 3), (4, 6), (7, float('inf'))]
    `count_stats` may look like: {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}

    In this example, bin (2, 3) receives 30 samples of value 2 and 40 samples of
    value 3, so its representative value is the count-weighted mean
    (30 * 2 + 40 * 3) / (30 + 40) = 180 / 70 ≈ 2.571.

    Returns a tuple (bin_centers, bin_counts); both lists have the same length as
    `bins`, holding the representative value and total sample count of each bin.
    """
    bin_counts = [0] * len(bins)
    bin_sums = [0] * len(bins)
    # Accumulate each bin's sample count and value-weighted sum. Keys and values
    # are cast with int() defensively, since they may arrive as strings (e.g. when
    # the statistics were deserialized from JSON).
    for k, v in count_stats.items():
        for i, (start, end) in enumerate(bins):
            if start <= int(k) <= end:
                bin_counts[i] += int(v)
                bin_sums[i] += int(v) * int(k)
                break
    bin_centers = []
    for i, (s, c) in enumerate(zip(bin_sums, bin_counts)):
        if c > 0:
            bin_centers.append(s / c)  # count-weighted mean of the bin
        else:
            # Empty bin: fall back to the midpoint, or to the lower edge for an
            # open-ended bin.
            start, end = bins[i]
            if end == float("inf"):
                bin_centers.append(float(start))
            else:
                bin_centers.append((start + end) / 2)
    return bin_centers, bin_counts
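
# Worked example for calc_bin_center (matches the docstring; the count
# statistics are hypothetical):
#
#   >>> bins = [(0, 0), (1, 1), (2, 3), (4, 6), (7, float("inf"))]
#   >>> stats = {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}
#   >>> centers, counts = calc_bin_center(bins, stats)
#   >>> [round(c, 3) for c in centers]
#   [0.0, 1.0, 2.571, 5.111, 8.074]
#   >>> counts
#   [10, 20, 70, 180, 270]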


def get_dataloader(args: Namespace, split: str = "train") -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    ddp = args.nprocs > 1
    if split == "train":  # training split: strong augmentation
        transforms = [
            datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.aug_min_scale, args.aug_max_scale)),
            datasets.RandomHorizontalFlip(),
        ]
        if args.aug_brightness > 0 or args.aug_contrast > 0 or args.aug_saturation > 0 or args.aug_hue > 0:
            transforms.append(datasets.ColorJitter(
                brightness=args.aug_brightness, contrast=args.aug_contrast, saturation=args.aug_saturation, hue=args.aug_hue
            ))
        if args.aug_blur_prob > 0 and args.aug_kernel_size > 0:
            transforms.append(datasets.RandomApply([
                datasets.GaussianBlur(kernel_size=args.aug_kernel_size),
            ], p=args.aug_blur_prob))
        if args.aug_saltiness > 0 or args.aug_spiciness > 0:
            transforms.append(datasets.PepperSaltNoise(
                saltiness=args.aug_saltiness, spiciness=args.aug_spiciness,
            ))
        transforms = Compose(transforms)
    elif args.sliding_window and args.resize_to_multiple:
        # Evaluation with sliding-window inference: resize inputs to dimensions
        # compatible with the window size and stride.
        transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
    else:
        transforms = None
    dataset_class = datasets.InMemoryCrowd if args.in_memory_dataset else datasets.Crowd
    # prefetch_factor and persistent_workers are only meaningful with worker processes.
    prefetch_factor = None if args.num_workers == 0 else 3
    persistent_workers = args.num_workers > 0
    dataset = dataset_class(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if split == "train" else 1,
        num_classes=args.num_classes,
    )
    if ddp and split == "train":  # data_loader for training in DDP
        # DistributedSampler requires the same seed on every rank so that all
        # replicas draw the same permutation and shard it disjointly; a per-rank
        # offset (e.g. seed + local_rank) would break that guarantee.
        sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=True, seed=args.seed)
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )
        return data_loader, sampler
    elif (not ddp) and split == "train":  # data_loader for single-process training
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )
        return data_loader, None
    elif ddp and split == "val":
        sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=False)
        data_loader = DataLoader(
            dataset,
            batch_size=1,  # use batch size 1 for evaluation
            sampler=sampler,  # the sampler fixes the order; shuffle must not be set alongside it
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )
        return data_loader
    else:  # (not ddp) and split == "val"
        data_loader = DataLoader(
            dataset,
            batch_size=1,  # use batch size 1 for evaluation
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )
        return data_loader
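
# Hedged usage sketch (illustrative, not part of the original module): the
# attribute names below are exactly those get_dataloader reads, but every value
# and the placeholder dataset name are assumptions; the real defaults live in
# the project's argument parser.
#
#   args = Namespace(
#       nprocs=1, local_rank=0, seed=42, dataset="<dataset-name>",
#       input_size=448, aug_min_scale=0.5, aug_max_scale=2.0,
#       aug_brightness=0.2, aug_contrast=0.2, aug_saturation=0.2, aug_hue=0.0,
#       aug_blur_prob=0.5, aug_kernel_size=5, aug_saltiness=0.0, aug_spiciness=0.0,
#       sliding_window=False, resize_to_multiple=False, window_size=224, stride=224,
#       in_memory_dataset=False, num_workers=4, num_crops=1, num_classes=1,
#       batch_size=8,
#   )
#   train_loader, train_sampler = get_dataloader(args, split="train")  # sampler is None without DDP
#   val_loader = get_dataloader(args, split="val")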