File size: 4,451 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import tqdm
import copy
import random
from collections import defaultdict
from libs.utils import logger
import numpy as np

class BucketSampler:
    """Batch sampler that groups dataset items into same-size buckets.

    At least one of ``fix_batch_size`` / ``max_pixel_nums`` must be given:
    either every batch has a fixed size, or the batch size per bucket is
    derived from a pixel budget (clamped to [min_batch_size, max_batch_size]).
    ``seps`` gives the bucket width along each size dimension.
    """

    def __init__(self, dataset, world_size, rank_id, fix_batch_size=None, max_pixel_nums=None, max_batch_size=8, min_batch_size=1, seps=(100, 100, 20)):
        # Distributed context: total rank count and this process's rank.
        self.dataset, self.world_size, self.rank_id = dataset, world_size, rank_id
        # Bucketing granularity per size dimension.
        self.seps = seps
        # Batch sizing policy.
        self.fix_batch_size, self.max_batch_size = fix_batch_size, max_batch_size
        self.min_batch_size, self.max_pixel_nums = min_batch_size, max_pixel_nums
        # One sizing policy is mandatory; NOTE(review): assert is stripped under -O.
        assert fix_batch_size is not None or max_pixel_nums is not None
        self.cal_buckets()
        # Shuffle state; epoch is advanced externally via set_epoch().
        self.seed = 20
        self.epoch = 0

def count_keys(self):
    """Return the size-info tuple of every sample in the dataset (with a progress bar)."""
    return [self.dataset.get_info(i) for i in tqdm.tqdm(range(len(self.dataset)))]

def cal_buckets(self):
    """Partition the dataset into size buckets and assign each a batch size.

    Each sample's info tuple (e.g. sizes along several dimensions) is
    quantized per dimension with bucket width ``self.seps``; samples sharing
    the same quantized key land in the same bucket. Buckets that are too
    small or exceed the pixel budget are dropped. Stores the result in
    ``self.buckets`` as a list of ``{'samples': [...], 'batch_size': int}``,
    sorted by bucket key so every rank builds the identical list.

    Side effects: writes 'count_keys.npy' and 'buckets.npy' to the CWD.
    """
    infos = self.count_keys()
    np.save('count_keys.npy', infos)

    # Element-wise min/max over all info tuples.
    min_sizes = None  # e.g. (64, 18, 2)
    max_sizes = None  # e.g. (1223, 742, 2080)
    for info in infos:
        if min_sizes is None:
            min_sizes = info
            max_sizes = info
        else:
            min_sizes = tuple(min(min_sizes[idx], info[idx]) for idx in range(len(min_sizes)))
            max_sizes = tuple(max(max_sizes[idx], info[idx]) for idx in range(len(max_sizes)))
    assert (min_sizes is not None) and (len(self.seps) == len(min_sizes))
    print('max sizes: {}, min size: {}'.format(max_sizes, min_sizes))

    # Quantize each sample into a joint key like '3-0-12'.
    buckets = defaultdict(list)
    for idx, info in enumerate(infos):
        bucket_idxes = list()
        for sep, size, min_size in zip(self.seps, info, min_sizes):
            bucket_idx = (size - min_size) // sep
            bucket_idxes.append(str(bucket_idx))
        bucket_idxes = '-'.join(bucket_idxes)
        buckets[bucket_idxes].append(idx)
    np.save('buckets.npy', buckets)

    valid_buckets = dict()
    for bucket_key, bucket_samples in buckets.items():
        # Too few samples to ever form a minimal batch.
        if len(bucket_samples) < self.min_batch_size:
            continue
        if (self.fix_batch_size is not None) and (len(bucket_samples) < self.fix_batch_size):
            continue

        # Upper-bound sizes of this bucket; only the first two (w, h) matter
        # for the pixel budget, extra dimensions are ignored.
        w, h, *_ = [(int(item) + 1) * sep + min_size for item, min_size, sep in zip(bucket_key.split('-'), min_sizes, self.seps)]
        # BUGFIX: apply the pixel budget only when one was given. Previously,
        # fix_batch_size set with max_pixel_nums=None crashed here on
        # `int > None` (TypeError), even though __init__ allows that combo.
        if self.max_pixel_nums is not None:
            if self.fix_batch_size is not None:
                if h * w * self.fix_batch_size > self.max_pixel_nums:
                    continue
            else:
                if h * w * self.min_batch_size > self.max_pixel_nums:
                    continue

        if self.fix_batch_size is not None:
            batch_size = self.fix_batch_size
        else:
            # Fit as many samples as the budget allows, clamped to
            # [min_batch_size, max_batch_size] and the bucket population.
            # (fix_batch_size is None here, so max_pixel_nums is set.)
            batch_size = min(self.max_batch_size, max(self.max_pixel_nums // (w * h), self.min_batch_size), len(bucket_samples))

        valid_buckets[bucket_key] = dict(
            samples=bucket_samples,
            batch_size=batch_size
        )

    # Sorted key order keeps bucket iteration deterministic across processes.
    self.buckets = [valid_buckets[bucket_key] for bucket_key in sorted(valid_buckets.keys())]
    total_nums = len(infos)
    valid_nums = sum(len(item['samples']) for item in valid_buckets.values())
    logger.info('Total %d samples, but ignore %d samples' % (total_nums, total_nums - valid_nums))

def __iter__(self):
    """Yield this rank's shuffled batches (lists of sample indices) for one epoch.

    Seeded by ``seed + epoch`` so all ranks shuffle identically, then batches
    are dealt round-robin by rank after truncating to a multiple of world_size.
    """
    rng = random.Random(self.seed + self.epoch)
    batches = list()
    for bucket in self.buckets:
        pool = copy.deepcopy(bucket['samples'])
        step = bucket['batch_size']
        rng.shuffle(pool)
        for start in range(0, len(pool), step):
            candidate = pool[start:start + step]
            # Keep a trailing remainder only if it meets the minimum batch size.
            if len(candidate) >= self.min_batch_size:
                batches.append(candidate)
    rng.shuffle(batches)

    # Truncate so every rank receives exactly the same number of batches.
    usable = (len(batches) // self.world_size) * self.world_size
    del batches[usable:]
    for pos in range(self.rank_id, len(batches), self.world_size):
        yield batches[pos]

def __len__(self):
    """Return the number of batches this rank yields per epoch."""
    total = 0
    for bucket in self.buckets:
        full, remainder = divmod(len(bucket["samples"]), bucket['batch_size'])
        total += full
        # A leftover chunk still counts if it reaches the minimum batch size.
        if remainder >= self.min_batch_size:
            total += 1
    # Per-rank share after the round-robin split across world_size ranks.
    return total // self.world_size

def set_epoch(self, epoch):
    """Record the current epoch so __iter__ reseeds its shuffling per epoch."""
    self.epoch = epoch