| from torch.optim.lr_scheduler import _LRScheduler | |
class PolyLr(_LRScheduler):
    """Polynomial learning-rate decay with an optional linear warmup.

    After warmup, each parameter group's learning rate follows

        lr = (base_lr - minimum_lr) * (1 - step / max_iteration) ** gamma + minimum_lr

    where ``step`` is ``self.last_epoch``. ``step`` is clamped to
    ``max_iteration``, so the rate settles at ``minimum_lr`` once the schedule
    is exhausted instead of being evaluated past its end.

    During the first ``warmup_iteration`` steps the rate ramps linearly from
    ``0.1 * base_lr`` towards ``base_lr``. It is also capped by the polynomial
    value for that step.

    Args:
        optimizer: Wrapped optimizer.
        gamma: Exponent of the polynomial decay.
        max_iteration: Step count at which the decay reaches ``minimum_lr``;
            must be positive.
        minimum_lr: Value the learning rate decays to.
        warmup_iteration: Number of warmup steps (0 disables warmup).
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.

    Raises:
        ValueError: If ``max_iteration`` is not positive.
    """

    def __init__(self, optimizer, gamma, max_iteration, minimum_lr=0, warmup_iteration=0, last_epoch=-1):
        if max_iteration <= 0:
            raise ValueError(f"max_iteration must be positive, got {max_iteration}")
        self.gamma = gamma
        self.max_iteration = max_iteration
        self.minimum_lr = minimum_lr
        self.warmup_iteration = warmup_iteration
        # These attributes must exist before the base initializer runs: it
        # performs an initial step(), which calls get_lr(). The base class
        # itself sets last_epoch and base_lrs.
        super().__init__(optimizer, last_epoch)

    def poly_lr(self, base_lr, step):
        """Return the polynomially decayed rate for ``base_lr`` at ``step``."""
        # Clamp the step. Past max_iteration, (1 - progress) is negative, and a
        # negative base raised to a fractional gamma yields a complex number.
        progress = min(step, self.max_iteration) / self.max_iteration
        return (base_lr - self.minimum_lr) * ((1 - progress) ** self.gamma) + self.minimum_lr

    def warmup_lr(self, base_lr, alpha):
        """Interpolate linearly from ``0.1 * base_lr`` (alpha=0) to ``base_lr`` (alpha=1)."""
        return base_lr * (1 / 10.0 * (1 - alpha) + alpha)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``self.last_epoch``."""
        step = self.last_epoch
        if step < self.warmup_iteration:
            alpha = step / self.warmup_iteration
            # During warmup, take the smaller of the warmup ramp and the decay curve.
            return [min(self.warmup_lr(base_lr, alpha), self.poly_lr(base_lr, step))
                    for base_lr in self.base_lrs]
        return [self.poly_lr(base_lr, step) for base_lr in self.base_lrs]