Spaces:
Runtime error
Runtime error
GLIP-BLIP-Object-Detection-VQA
/
maskrcnn_benchmark
/data
/samplers
/iteration_based_batch_sampler.py
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| from torch.utils.data.sampler import BatchSampler | |
class IterationBasedBatchSampler(BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled.

    Counting starts at ``start_iter``, so one full pass over this sampler
    yields ``num_iterations - start_iter`` batches (none when ``start_iter``
    is already ``>= num_iterations``). Each restart of the wrapped sampler
    begins from its own first batch; ``start_iter`` only shifts the count.

    Args:
        batch_sampler: iterable of batches (typically a ``BatchSampler``).
            If it exposes a ``sampler`` attribute with a ``set_epoch``
            method, that method is called at the start of every pass.
        num_iterations: total iteration count to stop at.
        start_iter: iteration to resume counting from.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        # BatchSampler.__init__ is deliberately not called: this class does
        # no sampling itself and only delegates to ``batch_sampler``.
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        """
        Yield batches from the wrapped sampler, restarting it as needed,
        until ``num_iterations`` is reached.

        Raises:
            ValueError: if a pass over ``batch_sampler`` yields no batches
                while iterations remain (the target could never be reached).
        """
        iteration = self.start_iter
        # Looked up defensively so plain iterables of batches (which have no
        # ``sampler`` attribute) can be wrapped as well.
        sampler = getattr(self.batch_sampler, "sampler", None)
        while iteration < self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it.
            # Note: the current iteration count is passed as the "epoch".
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(iteration)
            pass_start = iteration
            for batch in self.batch_sampler:
                iteration += 1
                yield batch
                # Stop right after the final batch instead of drawing one
                # more batch (or starting another pass) only to discard it.
                if iteration >= self.num_iterations:
                    return
            if iteration == pass_start:
                # Without this guard an empty batch sampler spins forever.
                raise ValueError(
                    "batch_sampler yielded no batches; "
                    "cannot reach num_iterations"
                )

    def __len__(self):
        # Total target iteration count, not the number remaining after
        # ``start_iter``; kept this way for callers that compare the current
        # iteration against it.
        return self.num_iterations