# LibContinual/core/scheduler.py
import math

from torch.optim import Optimizer

class _LRScheduler(object):

    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            # Fresh run: remember each group's starting lr so schedules can
            # always be derived from the base value.
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # Resuming: the checkpointed param_groups must carry 'initial_lr'.
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        # Apply the lr for epoch `last_epoch + 1`, then roll last_epoch back so
        # the first explicit step() call recomputes that same epoch.
        self.step(epoch=last_epoch + 1)
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_lr(self):
        # Subclasses compute the per-group lrs for self.last_epoch.
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        # Recompute the lr for the new epoch and write it into every group.
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
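

# Illustrative sketch, not part of the original module: the checkpoint round
# trip through state_dict()/load_state_dict(). `ConstantSchedule`, the dummy
# parameter, and `torch.optim.SGD` are assumptions made for this example only.
def _demo_state_dict_roundtrip():
    import torch

    class ConstantSchedule(_LRScheduler):
        # Hypothetical minimal subclass: keeps every group at its base lr.
        def get_lr(self):
            return list(self.base_lrs)

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = ConstantSchedule(optimizer)
    scheduler.step()
    state = scheduler.state_dict()  # everything in __dict__ except the optimizer
    resumed = ConstantSchedule(optimizer, last_epoch=scheduler.last_epoch)
    resumed.load_state_dict(state)  # restores last_epoch and base_lrs
    assert resumed.last_epoch == scheduler.last_epoch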


class CosineSchedule(_LRScheduler):
    """Cosine decay over K steps: the lr falls from base_lr towards
    base_lr * cos(99 * pi / 200) (about 1.6% of base_lr) at step K - 1."""

    def __init__(self, optimizer, K):
        self.K = K
        super().__init__(optimizer, -1)

    def cosine(self, base_lr):
        # Guard K == 1, where (K - 1) would divide by zero; treat it like K == 2.
        denom = max(self.K - 1, 1)
        return base_lr * math.cos((99 * math.pi * self.last_epoch) / (200 * denom))

    def get_lr(self):
        return [self.cosine(base_lr) for base_lr in self.base_lrs]

    def get_last_lr(self):
        return self.get_lr()
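

# Illustrative sketch, not part of the original module: drives CosineSchedule
# for K = 10 steps. The optimizer, lr, and dummy parameter are assumptions
# made for this example only.
def _demo_cosine_schedule():
    import torch
    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = CosineSchedule(optimizer, K=10)
    for epoch in range(10):
        scheduler.step(epoch)
        # ... one epoch of training with `optimizer` ...
        print(epoch, scheduler.get_last_lr())  # [0.1] down to ~[0.0016]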


class CosineAnnealingWarmUp(_LRScheduler):

    def __init__(self, optimizer, warmup_length, T_max=0, last_epoch=-1):
        self.warmup_length = warmup_length
        self.T_max = T_max
        super().__init__(optimizer, last_epoch)

    def cosine_lr(self, base_lr):
        # Cosine annealing over T_max epochs. The phase is not shifted by
        # warmup_length, so the lr drops discontinuously when the warmup ends.
        return base_lr * 0.5 * (1 + math.cos(math.pi * self.last_epoch / self.T_max))

    def warmup_lr(self, base_lr):
        # Linear warmup from base_lr / warmup_length up to base_lr.
        return base_lr * (self.last_epoch + 1) / self.warmup_length

    def get_lr(self):
        if self.last_epoch < self.warmup_length:
            return [self.warmup_lr(base_lr) for base_lr in self.base_lrs]
        else:
            return [self.cosine_lr(base_lr) for base_lr in self.base_lrs]

    def get_last_lr(self):
        assert self.T_max > 0, 'CosineAnnealingWarmUp was created with T_max <= 0; check your code'
        return self.get_lr()
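

# Illustrative sketch, not part of the original module: five epochs of linear
# warmup followed by cosine annealing over T_max = 50 epochs. The optimizer,
# lr, and dummy parameter are assumptions made for this example only.
def _demo_cosine_warmup():
    import torch
    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = CosineAnnealingWarmUp(optimizer, warmup_length=5, T_max=50)
    for epoch in range(50):
        scheduler.step(epoch)
        # epochs 0-4 ramp the lr 0.02 -> 0.1; later epochs decay it by cosine
        print(epoch, scheduler.get_last_lr())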


class PatienceSchedule(_LRScheduler):

    def __init__(self, optimizer, patience, factor):
        self.factor = factor            # factor by which the lr is divided
        self.patience = patience        # epochs with no improvement before reducing
        self.best_loss = float('inf')   # best loss seen so far
        self.counter = 0                # epochs since the last improvement
        super().__init__(optimizer, -1)

    def step(self, current_loss=None, **kwargs):
        # Some callers invoke step(epoch=...); such keyword arguments are
        # swallowed by **kwargs and ignored.
        if current_loss is None:
            return 0
        # Any improvement over the best loss so far resets the patience counter.
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.counter = 0
        else:
            self.counter += 1
        # Once patience is exhausted, divide every group's lr by the factor.
        if self.counter >= self.patience:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] /= self.factor
            print(f"Reducing learning rate to {self.optimizer.param_groups[0]['lr']:.5f}")
            self.counter = 0

    def get_last_lr(self):
        # Note: returns a scalar (the first group's lr), not a list like the
        # other schedulers in this module.
        return self.optimizer.param_groups[0]['lr']
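

# Illustrative sketch, not part of the original module: PatienceSchedule is
# driven by a validation loss rather than an epoch count. The made-up loss
# sequence stalls for two epochs (patience=2), so the lr is divided by 10.
def _demo_patience_schedule():
    import torch
    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = PatienceSchedule(optimizer, patience=2, factor=10)
    for val_loss in [1.0, 0.9, 0.9, 0.9, 0.8]:
        scheduler.step(current_loss=val_loss)  # the third 0.9 cuts the lr to 0.01
        print(val_loss, scheduler.get_last_lr())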