# training_sem/libs/data/batch_sampler.py
import copy
import random
from collections import defaultdict

import numpy as np
import tqdm

from libs.utils import logger


class BucketSampler:
    """Group dataset samples into size buckets and yield per-rank batches of similarly sized samples."""

    def __init__(self, dataset, world_size, rank_id, fix_batch_size=None, max_pixel_nums=None,
                 max_batch_size=8, min_batch_size=1, seps=(100, 100, 20)):
        self.dataset = dataset
        self.world_size = world_size
        self.rank_id = rank_id
        self.seps = seps
        self.fix_batch_size = fix_batch_size
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.max_pixel_nums = max_pixel_nums
        # Either a fixed batch size or a pixel budget must be provided.
        assert (fix_batch_size is not None) or (max_pixel_nums is not None)
        self.cal_buckets()
        self.seed = 20
        self.epoch = 0

    def count_keys(self):
        # Collect the size info tuple (e.g. width, height, length) for every sample.
        infos = []
        for i in tqdm.tqdm(range(len(self.dataset))):
            info = self.dataset.get_info(i)
            infos.append(info)
        return infos

    def cal_buckets(self):
        infos = self.count_keys()
        np.save('count_keys.npy', infos)
        min_sizes = None  # e.g. (64, 18, 2)
        max_sizes = None  # e.g. (1223, 742, 2080)
        for info in infos:
            if min_sizes is None:
                min_sizes = info
                max_sizes = info
            else:  # track the per-dimension minimum and maximum over all samples
                min_sizes = tuple(min(min_sizes[idx], info[idx]) for idx in range(len(min_sizes)))
                max_sizes = tuple(max(max_sizes[idx], info[idx]) for idx in range(len(max_sizes)))
        assert (min_sizes is not None) and (len(self.seps) == len(min_sizes))
        print('max sizes: {}, min sizes: {}'.format(max_sizes, min_sizes))
        # Assign each sample to a bucket keyed by its quantized size along every dimension.
        buckets = defaultdict(list)
        for idx, info in enumerate(infos):
            bucket_idxes = list()
            for sep, size, min_size in zip(self.seps, info, min_sizes):
                bucket_idx = (size - min_size) // sep
                bucket_idxes.append(str(bucket_idx))
            bucket_idxes = '-'.join(bucket_idxes)
            buckets[bucket_idxes].append(idx)
        np.save('buckets.npy', buckets)
        # Keep only buckets that can form at least one valid batch within the pixel budget.
        valid_buckets = dict()
        for bucket_key, bucket_samples in buckets.items():
            if len(bucket_samples) < self.min_batch_size:
                continue
            if (self.fix_batch_size is not None) and (len(bucket_samples) < self.fix_batch_size):
                continue
            # Upper edge of the bucket along each dimension; the first two are treated as width and height.
            w, h, *_ = [(int(item) + 1) * sep + min_size
                        for item, min_size, sep in zip(bucket_key.split('-'), min_sizes, self.seps)]
            if self.fix_batch_size is not None:
                # With a fixed batch size, the pixel budget (if given) still caps the bucket resolution.
                if (self.max_pixel_nums is not None) and (h * w * self.fix_batch_size > self.max_pixel_nums):
                    continue
            else:
                if h * w * self.min_batch_size > self.max_pixel_nums:
                    continue
            if self.fix_batch_size is not None:
                batch_size = self.fix_batch_size
            else:
                # Largest batch that fits the pixel budget, clamped to [min_batch_size, max_batch_size]
                # and to the number of samples available in the bucket.
                batch_size = min(self.max_batch_size,
                                 max(self.max_pixel_nums // (w * h), self.min_batch_size),
                                 len(bucket_samples))
            valid_buckets[bucket_key] = dict(
                samples=bucket_samples,
                batch_size=batch_size
            )
        self.buckets = [valid_buckets[bucket_key] for bucket_key in sorted(valid_buckets.keys())]
        total_nums = len(infos)
        valid_nums = sum([len(item['samples']) for item in valid_buckets.values()])
        logger.info('Total %d samples, ignored %d samples' % (total_nums, total_nums - valid_nums))

    def __iter__(self):
        # Deterministic shuffling that changes every epoch but is identical across ranks.
        random_inst = random.Random(self.seed + self.epoch)
        batches = list()
        for bucket in self.buckets:
            sample = copy.deepcopy(bucket['samples'])
            batch_size = bucket['batch_size']
            random_inst.shuffle(sample)
            idx = 0
            while idx < len(sample):
                batch = sample[idx:idx + batch_size]
                idx += batch_size
                # Drop undersized leftover chunks at the end of a bucket.
                if len(batch) < self.min_batch_size:
                    continue
                batches.append(batch)
        random_inst.shuffle(batches)
        # Truncate so every rank receives the same number of batches.
        align_nums = (len(batches) // self.world_size) * self.world_size
        batches = batches[:align_nums]
        for batch_idx in range(self.rank_id, len(batches), self.world_size):
            yield batches[batch_idx]

    def __len__(self):
        # Approximate per-rank batch count; the alignment in __iter__ may drop a few tail batches.
        batch_nums = 0
        for bucket in self.buckets:
            bucket_sample_nums = len(bucket['samples'])
            bucket_bs = bucket['batch_size']
            batch_nums += bucket_sample_nums // bucket_bs
            if bucket_sample_nums % bucket_bs >= self.min_batch_size:
                batch_nums += 1
        return batch_nums // self.world_size

    def set_epoch(self, epoch):
        # Call once per epoch before iterating so the shuffle order changes deterministically.
        self.epoch = epoch
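

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original training pipeline: the sampler only needs a
    # dataset exposing __len__ and get_info(idx), where get_info returns a size tuple whose length
    # matches `seps`. The dummy dataset below is a hypothetical stand-in for illustration; in
    # training, an instance of this class would typically be passed to a DataLoader as its
    # batch_sampler.
    class _DummyDataset:
        def __init__(self, nums=200):
            rng = random.Random(0)
            # Random (width, height, length) tuples roughly within the ranges noted in cal_buckets.
            self._infos = [(rng.randint(64, 1200), rng.randint(18, 740), rng.randint(2, 2000))
                           for _ in range(nums)]

        def __len__(self):
            return len(self._infos)

        def get_info(self, idx):
            return self._infos[idx]

    sampler = BucketSampler(_DummyDataset(), world_size=1, rank_id=0,
                            max_pixel_nums=1024 * 1024, max_batch_size=8)
    sampler.set_epoch(0)
    for batch in sampler:
        print(batch)  # indices of samples that fall into the same size bucket
        break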