| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import numpy as np |
| | import math |
| | import sys |
| | from torch.utils.data import Sampler |
| |
|
# Public API: the names exported by ``from <this module> import *``.
__all__ = [
    'BatchSampler',
    'DistributedBatchSampler',
    'RandomSampler',
    'SequentialSampler',
]
class BatchSampler(Sampler):
    """Group the indices produced by an inner sampler into lists.

    Indices are collected into lists of exactly ``batch_size`` elements. When
    the number of indices is not a multiple of ``batch_size``, the final,
    shorter batch is still yielded (it is never dropped).

    Args:
        sampler: Iterable of indices that also supports ``len()``.
        batch_size: Maximum number of indices per batch; must be positive.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, sampler, batch_size):
        # A non-positive size would make __len__ divide by zero (or return a
        # meaningless value) while __iter__ silently yields one giant batch.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.sampler = sampler
        self.batch_size = batch_size

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                # Start a fresh list; the yielded one now belongs to the consumer.
                batch = []
        # Emit the trailing partial batch, if any.
        if batch:
            yield batch

    def __len__(self):
        # Ceiling division: a trailing partial batch counts as one batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
| |
|
class DistributedBatchSampler(Sampler):
    """Shard every batch of an inner batch sampler across distributed ranks.

    Each batch produced by ``sampler`` is split into ``world_size`` equal,
    contiguous chunks, and this instance yields chunk number ``rank``. A batch
    whose length is not a multiple of ``world_size`` is handled in one of two
    ways. With ``drop_last=False`` it is padded by repeating its first element
    until it divides evenly. With ``drop_last=True`` iteration stops at that
    batch.

    Args:
        sampler: Iterable of batches (sequences of indices) that also
            supports ``len()``, e.g. a ``BatchSampler``.
        rank: Index of this process; must satisfy ``0 <= rank < world_size``.
        world_size: Total number of processes; must be positive.
        drop_last: If True, stop at the first batch whose length is not
            divisible by ``world_size`` instead of padding it.

    Raises:
        ValueError: If ``world_size`` is not positive or ``rank`` is outside
            ``[0, world_size)``.
    """

    def __init__(self, sampler, rank=0, world_size = 1, drop_last = False):
        # Fail fast: world_size == 0 would otherwise raise ZeroDivisionError
        # mid-iteration, and an out-of-range rank silently yields empty shards.
        if world_size <= 0:
            raise ValueError(f"world_size must be positive, got {world_size}")
        if not 0 <= rank < world_size:
            raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
        self.sampler = sampler
        self.rank = rank
        self.world_size = world_size
        self.drop_last = drop_last

    def __iter__(self):
        for batch in self.sampler:
            remainder = len(batch) % self.world_size
            if remainder:
                if self.drop_last:
                    # Stops iteration entirely. This only amounts to "dropping
                    # the last batch" when every earlier batch divides evenly.
                    break
                # Build a padded copy rather than extending in place. The
                # upstream batch may be a caller-owned or cached list, or an
                # immutable tuple.
                batch = list(batch) + [batch[0]] * (self.world_size - remainder)
            chunk_size = len(batch) // self.world_size
            start = self.rank * chunk_size
            yield batch[start:start + chunk_size]

    def __len__(self):
        # Number of batches in the inner sampler. NOTE: when drop_last=True
        # and __iter__ breaks early, this can exceed the number of batches
        # actually yielded.
        return len(self.sampler)
| |
|
class RandomSampler(Sampler):
    """Yield a seeded random ordering of ``range(total_samples)`` per pass.

    Orders come from a private ``np.random.RandomState`` seeded with
    ``data_seed``. When a pass begins, the stored index array is shuffled in
    place, so consecutive passes produce different orders. The whole sequence
    of orders is still reproducible from the seed. Items are NumPy integer
    scalars taken from that array.
    """

    def __init__(self, total_samples:int, data_seed:int = 0):
        self.indices = np.arange(total_samples)
        self.rng = np.random.RandomState(data_seed)

    def __iter__(self):
        # Reshuffle the current order (not a fresh arange), so each pass's
        # order depends on all earlier shuffles from the same generator.
        order = self.indices
        self.rng.shuffle(order)
        yield from order

    def __len__(self):
        return len(self.indices)
| |
|
class SequentialSampler(Sampler):
    """Yield ``0, 1, ..., total_samples - 1`` in ascending order on every pass.

    Items are NumPy integer scalars taken from a stored ``np.arange`` array.
    """

    def __init__(self, total_samples:int):
        self.indices = np.arange(total_samples)

    def __iter__(self):
        yield from self.indices

    def __len__(self):
        return len(self.indices)
| |
|