import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler


class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over Sampler for distributed training. It allows you to use any sampler in distributed mode.
    It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a torch.utils.data.DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        The dataset is assumed to be of constant size.

    Args:
        sampler: Sampler used for subsampling.
        num_replicas (int, optional): Number of processes participating in distributed training. By default,
            world_size is retrieved from the current distributed group.
        rank (int, optional): Rank of the current process within num_replicas. By default, rank is retrieved
            from the current distributed group.
        shuffle (bool, optional): If True, the sampler will shuffle the indices. Default: True.
        seed (int, optional): Random seed used to shuffle the sampler if shuffle=True. This number should be
            identical across all processes in the distributed group. Default: 0.

    Reference: https://github.com/pytorch/pytorch/issues/23430
    """

    def __init__(
        self,
        sampler,
        num_replicas: int = None,
        rank: int = None,
        shuffle: bool = True,
        seed: int = 0,
    ):
        # The wrapped sampler is passed where DistributedSampler expects a dataset,
        # so it is stored as `self.dataset` and iterated over in `__iter__`.
        super().__init__(
            sampler,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
        )

    def __iter__(self):
        indices = list(self.dataset)[: self.total_size]

        # Add extra samples to make the number of indices evenly divisible across replicas.
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}"

        # Subsample the chunk of indices assigned to this rank.
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}"

        return iter(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        # Propagate the epoch to the wrapped sampler so it can reshuffle deterministically.
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
        elif hasattr(self.dataset, "generator"):
            self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch)

    def state_dict(self):
        # Delegate checkpointing to the wrapped sampler (assumes it implements state_dict).
        return self.dataset.state_dict()

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict)


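# Illustrative usage sketch (not part of the original module): wrapping a weighted
# random sampler so it can be used with DistributedDataParallel. `dataset`,
# `sample_weights`, `num_gpus`, and `num_epochs` are assumed to be defined elsewhere.
#
#   sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(sample_weights))
#   if num_gpus > 1:
#       sampler = DistributedSamplerWrapper(sampler)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
#   for epoch in range(num_epochs):
#       if num_gpus > 1:
#           loader.sampler.set_epoch(epoch)  # keep shuffling in sync across processes
#       for batch in loader:
#           ...

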
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Noam learning rate schedule: the rate rises linearly for `warmup_steps` steps,
    then decays proportionally to the inverse square root of the step."""

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]


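# Illustrative usage sketch (not part of the original module): attach NoamLR to an
# optimizer and step it once per training iteration. `model`, `loader`, and
# `train_step` are assumed for the example; `warmup_steps` is a step count.
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler = NoamLR(optimizer, warmup_steps=4000)
#   for batch in loader:
#       loss = train_step(batch)
#       optimizer.step()
#       scheduler.step()  # lr ramps up linearly for 4000 steps, then decays ~ step**-0.5

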
class NoamLRStepConstant(torch.optim.lr_scheduler._LRScheduler):
    """Noam schedule that is frozen once `threshold_step` is reached: the effective step
    is clamped at `threshold_step`, so the learning rate stays constant afterwards."""

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1, threshold_step=100):
        self.warmup_steps = float(warmup_steps)
        self.threshold_step = threshold_step
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Clamp the step at `threshold_step` so the schedule stops changing beyond it.
        step = min(max(self.last_epoch, 1), self.threshold_step)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]


class NoamLRStepDecay(torch.optim.lr_scheduler._LRScheduler):
    """Noam schedule that, once `threshold_step` is reached, walks the effective step
    backwards by one on every subsequent call, retracing the schedule curve."""

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1, threshold_step=100):
        self.warmup_steps = float(warmup_steps)
        self.threshold_step = threshold_step
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        if step >= self.threshold_step:
            # Decrement the stored threshold on each call and use it as the effective step.
            self.threshold_step -= 1
            step = max(self.threshold_step, 1)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]


class StepwiseGradualLR(torch.optim.lr_scheduler._LRScheduler):
    """Hardcoded step-wise learning rate scheduling. Necessary for CapacitronVAE.

    Args:
        gradual_learning_rates: Sequence of (step_threshold, learning_rate) pairs. The rate of the
            last threshold that the current step has passed is used; the first pair holds the
            initial learning rate.
    """

    def __init__(self, optimizer, gradual_learning_rates, last_epoch=-1):
        self.gradual_learning_rates = gradual_learning_rates
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        step_thresholds = []
        rates = []
        for values in self.gradual_learning_rates:
            step_thresholds.append(values[0])
            rates.append(values[1])

        # Pick the rate of the last threshold that the current step has passed.
        boolean_indices = np.less_equal(step_thresholds, step)
        try:
            last_true = np.where(boolean_indices)[0][-1]
        except IndexError:
            # No threshold reached yet; fall back to the first rate.
            last_true = 0
        lr = rates[last_true]

        # Return the last rate if the step is beyond the last threshold.
        lr = rates[-1] if step > step_thresholds[-1] else lr
        # Return the first rate until the second threshold is reached - the first pair holds the initial lr.
        lr = rates[0] if step < step_thresholds[1] else lr

        return np.tile(lr, len(self.base_lrs))
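

# Illustrative usage sketch (not part of the original module): StepwiseGradualLR takes
# a list of (step_threshold, learning_rate) pairs. The thresholds and rates below are
# made-up example values, as are `model`, `loader`, and `train_step`.
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler = StepwiseGradualLR(
#       optimizer,
#       gradual_learning_rates=[(0, 1e-3), (20_000, 5e-4), (40_000, 3e-4), (60_000, 1e-4)],
#   )
#   for batch in loader:
#       loss = train_step(batch)
#       optimizer.step()
#       scheduler.step()  # lr follows the hardcoded (step, lr) table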