###
# Modified by Francesco Laiti - date 23 February 2024
# Fetched from https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/optim/lr_scheduler.py
###
import torch
from torch.optim.lr_scheduler import _LRScheduler

AVAI_SCHEDS = ["single_step", "multi_step", "cosine"]

class _BaseWarmupScheduler(_LRScheduler):
    """Base class that applies a warmup policy before handing control to a successor scheduler."""

    def __init__(
        self, optimizer, successor, warmup_epoch, last_epoch=-1, verbose=False
    ):
        self.successor = successor
        self.warmup_epoch = warmup_epoch
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        raise NotImplementedError

    def step(self, epoch=None):
        # Once warmup is over, delegate stepping to the wrapped (successor) scheduler
        if self.last_epoch >= self.warmup_epoch:
            self.successor.step(epoch)
            self._last_lr = self.successor.get_last_lr()
        else:
            super().step(epoch)

class ConstantWarmupScheduler(_BaseWarmupScheduler):
    """Keeps the learning rate fixed at ``cons_lr`` during warmup."""

    def __init__(
        self, optimizer, successor, warmup_epoch, cons_lr, last_epoch=-1, verbose=False
    ):
        self.cons_lr = cons_lr
        super().__init__(optimizer, successor, warmup_epoch, last_epoch, verbose)

    def get_lr(self):
        if self.last_epoch >= self.warmup_epoch:
            return self.successor.get_last_lr()
        return [self.cons_lr for _ in self.base_lrs]

class LinearWarmupScheduler(_BaseWarmupScheduler):
    """Linearly ramps the learning rate from ``min_lr`` up to the base lr during warmup."""

    def __init__(
        self, optimizer, successor, warmup_epoch, min_lr, last_epoch=-1, verbose=False
    ):
        self.min_lr = min_lr
        super().__init__(optimizer, successor, warmup_epoch, last_epoch, verbose)

    def get_lr(self):
        if self.last_epoch >= self.warmup_epoch:
            return self.successor.get_last_lr()
        if self.last_epoch == 0:
            return [self.min_lr for _ in self.base_lrs]
        return [lr * self.last_epoch / self.warmup_epoch for lr in self.base_lrs]

def build_lr_scheduler(
    optimizer,
    lr_scheduler,
    max_epoch,
    warmup_epoch=0,
    warmup_recount=False,
    warmup_type=None,
    warmup_cons_lr=0.01,
    warmup_min_lr=0.001,
    stepsize=None,
    gamma=None,
):
    """A function wrapper for building a learning rate scheduler.

    Args:
        optimizer (Optimizer): an Optimizer.
        lr_scheduler (str): Type of learning rate scheduler.
        max_epoch (int): Maximum number of epochs.
        warmup_epoch (int, optional): Number of warmup epochs.
        warmup_recount (bool, optional): If False, the wrapped scheduler starts
            counting from ``warmup_epoch`` rather than restarting at 0.
        warmup_type (str, optional): Type of warmup ('constant' or 'linear').
        warmup_cons_lr (float, optional): Learning rate for constant warmup.
        warmup_min_lr (float, optional): Minimum learning rate for linear warmup.
        stepsize (int or list/tuple, optional): Step size for learning rate decay.
        gamma (float, optional): Multiplicative factor of learning rate decay.
    """
    if lr_scheduler not in AVAI_SCHEDS:
        raise ValueError(
            f"scheduler must be one of {AVAI_SCHEDS}, but got {lr_scheduler}"
        )

    if lr_scheduler == "single_step":
        if isinstance(stepsize, (list, tuple)):
            stepsize = stepsize[-1]

        if not isinstance(stepsize, int):
            raise TypeError(
                "For single_step lr_scheduler, stepsize must "
                f"be an integer, but got {type(stepsize)}"
            )

        if stepsize <= 0:
            stepsize = max_epoch

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=stepsize, gamma=gamma
        )

    elif lr_scheduler == "multi_step":
        if not isinstance(stepsize, (list, tuple)):
            raise TypeError(
                "For multi_step lr_scheduler, stepsize must "
                f"be a list, but got {type(stepsize)}"
            )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=stepsize, gamma=gamma
        )

    elif lr_scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=float(max_epoch)
        )

    if warmup_epoch > 0:
        if not warmup_recount:
            scheduler.last_epoch = warmup_epoch

        if warmup_type == "constant":
            scheduler = ConstantWarmupScheduler(
                optimizer, scheduler, warmup_epoch, warmup_cons_lr
            )

        elif warmup_type == "linear":
            scheduler = LinearWarmupScheduler(
                optimizer, scheduler, warmup_epoch, warmup_min_lr
            )

        else:
            raise ValueError(
                f"warmup_type must be 'constant' or 'linear', but got {warmup_type}"
            )

    return scheduler
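

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: the epoch
    # counts and learning-rate values below are assumptions chosen only to
    # show how a cosine schedule with linear warmup behaves.
    param = torch.nn.Parameter(torch.zeros(1))
    optim = torch.optim.SGD([param], lr=0.1)

    sched = build_lr_scheduler(
        optim,
        lr_scheduler="cosine",
        max_epoch=20,
        warmup_epoch=5,
        warmup_type="linear",
        warmup_min_lr=1e-4,
    )

    for epoch in range(20):
        # optimizer.step() would normally follow the training loop for this epoch
        optim.step()
        sched.step()
        print(f"epoch {epoch}: lr = {sched.get_last_lr()[0]:.5f}")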