Spaces:
Sleeping
Sleeping
| ### | |
| # Modified by Francesco Laiti - date 23 February 2024 | |
| # Fetched from https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/optim/lr_scheduler.py | |
| ### | |
| import torch | |
| from torch.optim.lr_scheduler import _LRScheduler | |
# Scheduler names accepted by build_lr_scheduler; any other name is rejected
# there with a ValueError.
AVAI_SCHEDS = ["single_step", "multi_step", "cosine"]
class _BaseWarmupScheduler(_LRScheduler):
    """Base class for schedulers that run a warmup phase before a successor.

    While ``last_epoch`` is below ``warmup_epoch`` the learning rate comes from
    the subclass's ``get_lr``. From then on every ``step`` call is delegated to
    the wrapped ``successor`` scheduler.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        successor: Scheduler that takes over once warmup is finished. It is
            expected to drive the same optimizer.
        warmup_epoch: Number of epochs handled by the warmup phase.
        last_epoch: Index of the last epoch, forwarded to ``_LRScheduler``.
        verbose: Forwarded to ``_LRScheduler`` only when truthy (see below).
    """

    def __init__(
        self, optimizer, successor, warmup_epoch, last_epoch=-1, verbose=False
    ):
        # These attributes must exist before the parent constructor runs,
        # because it performs an initial ``step()`` that reads them.
        self.successor = successor
        self.warmup_epoch = warmup_epoch
        if verbose:
            # Explicitly requested: keep the historical pass-through.
            super().__init__(optimizer, last_epoch, verbose)
        else:
            # Newer PyTorch releases deprecate (and may no longer accept) the
            # ``verbose`` argument. Omitting it is equivalent to the old
            # default of ``False`` and avoids a warning or TypeError there.
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the warmup learning rates; subclasses must implement this."""
        raise NotImplementedError

    def step(self, epoch=None):
        """Advance the schedule by one step.

        During warmup the regular ``_LRScheduler.step`` runs. Afterwards the
        successor is stepped instead and its last learning rates are mirrored
        into ``_last_lr`` so ``get_last_lr()`` on this object stays accurate.
        """
        if self.last_epoch >= self.warmup_epoch:
            self.successor.step(epoch)
            self._last_lr = self.successor.get_last_lr()
        else:
            super().step(epoch)
class ConstantWarmupScheduler(_BaseWarmupScheduler):
    """Warmup that holds every parameter group at one fixed learning rate.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        successor: Scheduler used once the warmup phase has ended.
        warmup_epoch: Number of warmup epochs.
        cons_lr: Learning rate applied to all groups during warmup.
        last_epoch: Index of the last epoch, passed to the base class.
        verbose: Passed to the base class.
    """

    def __init__(
        self,
        optimizer,
        successor,
        warmup_epoch,
        cons_lr,
        last_epoch=-1,
        verbose=False,
    ):
        # Assigned before the base constructor runs: ``get_lr`` reads it.
        self.cons_lr = cons_lr
        super().__init__(optimizer, successor, warmup_epoch, last_epoch, verbose)

    def get_lr(self):
        """Return ``cons_lr`` per group in warmup, else the successor's rates."""
        still_warming_up = self.last_epoch < self.warmup_epoch
        if still_warming_up:
            return [self.cons_lr] * len(self.base_lrs)
        return self.successor.get_last_lr()
class LinearWarmupScheduler(_BaseWarmupScheduler):
    """Warmup that scales each base learning rate linearly with the epoch.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        successor: Scheduler used once the warmup phase has ended.
        warmup_epoch: Number of warmup epochs.
        min_lr: Learning rate applied to all groups at epoch 0.
        last_epoch: Index of the last epoch, passed to the base class.
        verbose: Passed to the base class.
    """

    def __init__(
        self,
        optimizer,
        successor,
        warmup_epoch,
        min_lr,
        last_epoch=-1,
        verbose=False,
    ):
        # Assigned before the base constructor runs: ``get_lr`` reads it.
        self.min_lr = min_lr
        super().__init__(optimizer, successor, warmup_epoch, last_epoch, verbose)

    def get_lr(self):
        """Return the learning rates for the current epoch.

        After warmup the successor's last rates are returned. At epoch 0 every
        group gets ``min_lr`` (the linear formula would give zero there).
        Otherwise each base rate is scaled by ``last_epoch / warmup_epoch``.
        """
        current = self.last_epoch
        total = self.warmup_epoch
        if current >= total:
            return self.successor.get_last_lr()
        return (
            [self.min_lr] * len(self.base_lrs)
            if current == 0
            else [base * current / total for base in self.base_lrs]
        )
def build_lr_scheduler(
    optimizer,
    lr_scheduler,
    max_epoch,
    warmup_epoch=0,
    warmup_recount=False,
    warmup_type=None,
    warmup_cons_lr=0.01,
    warmup_min_lr=0.001,
    stepsize=None,
    gamma=None,
):
    """Build a learning rate scheduler, optionally preceded by a warmup phase.

    Args:
        optimizer (Optimizer): an Optimizer.
        lr_scheduler (str): Type of learning rate scheduler; one of
            ``AVAI_SCHEDS``.
        max_epoch (int): Maximum number of epochs. Used as ``T_max`` for the
            cosine schedule and as the fallback step size for ``single_step``
            when ``stepsize <= 0``.
        warmup_epoch (int, optional): Number of warmup epochs. Warmup is only
            applied when this is greater than 0.
        warmup_recount (bool, optional): If False, the wrapped scheduler's
            ``last_epoch`` is set to ``warmup_epoch`` before warmup starts.
            If True, it is left untouched.
        warmup_type (str, optional): Type of warmup ('constant' or 'linear').
        warmup_cons_lr (float, optional): Learning rate for constant warmup.
        warmup_min_lr (float, optional): Minimum learning rate for linear
            warmup.
        stepsize (int or list/tuple): Step size for learning rate decay. For
            ``single_step`` a list/tuple is reduced to its last element; for
            ``multi_step`` it is used as the milestones.
        gamma (float, optional): Multiplicative factor of learning rate decay.
            When None, the argument is not forwarded, so the PyTorch scheduler
            uses its own default.

    Returns:
        The constructed scheduler. When ``warmup_epoch > 0`` it is wrapped in
        a ``ConstantWarmupScheduler`` or ``LinearWarmupScheduler``.

    Raises:
        ValueError: If ``lr_scheduler`` is not in ``AVAI_SCHEDS``, or if warmup
            is requested with a ``warmup_type`` other than 'constant' or
            'linear'.
        TypeError: If ``stepsize`` has the wrong type for the chosen scheduler.
    """
    if lr_scheduler not in AVAI_SCHEDS:
        raise ValueError(
            f"scheduler must be one of {AVAI_SCHEDS}, but got {lr_scheduler}"
        )

    # Forward gamma only when the caller supplied one. Passing None would
    # override the scheduler's default and fail at the first decay step.
    decay_kwargs = {} if gamma is None else {"gamma": gamma}

    if lr_scheduler == "single_step":
        if isinstance(stepsize, (list, tuple)):
            stepsize = stepsize[-1]

        if not isinstance(stepsize, int):
            raise TypeError(
                "For single_step lr_scheduler, stepsize must "
                f"be an integer, but got {type(stepsize)}"
            )

        # A non-positive step size means "decay once, after max_epoch epochs".
        if stepsize <= 0:
            stepsize = max_epoch

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=stepsize, **decay_kwargs
        )

    elif lr_scheduler == "multi_step":
        if not isinstance(stepsize, (list, tuple)):
            raise TypeError(
                "For multi_step lr_scheduler, stepsize must "
                f"be a list, but got {type(stepsize)}"
            )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=stepsize, **decay_kwargs
        )

    elif lr_scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=float(max_epoch)
        )

    if warmup_epoch > 0:
        if not warmup_recount:
            # Start the successor's epoch counter at warmup_epoch rather than
            # at the value left by its own construction.
            scheduler.last_epoch = warmup_epoch

        if warmup_type == "constant":
            scheduler = ConstantWarmupScheduler(
                optimizer, scheduler, warmup_epoch, warmup_cons_lr
            )

        elif warmup_type == "linear":
            scheduler = LinearWarmupScheduler(
                optimizer, scheduler, warmup_epoch, warmup_min_lr
            )

        else:
            raise ValueError(
                "warmup_type must be 'constant' or 'linear' when "
                f"warmup_epoch > 0, but got {warmup_type!r}"
            )

    return scheduler