""" Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ from torch.optim.lr_scheduler import LRScheduler from ..core import register class Warmup(object): def __init__(self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int=-1) -> None: self.lr_scheduler = lr_scheduler self.warmup_end_values = [pg['lr'] for pg in lr_scheduler.optimizer.param_groups] self.last_step = last_step self.warmup_duration = warmup_duration self.step() def state_dict(self): return {k: v for k, v in self.__dict__.items() if k != 'lr_scheduler'} def load_state_dict(self, state_dict): self.__dict__.update(state_dict) def get_warmup_factor(self, step, **kwargs): raise NotImplementedError def step(self, ): self.last_step += 1 if self.last_step >= self.warmup_duration: return factor = self.get_warmup_factor(self.last_step) for i, pg in enumerate(self.lr_scheduler.optimizer.param_groups): pg['lr'] = factor * self.warmup_end_values[i] def finished(self, ): if self.last_step >= self.warmup_duration: return True return False @register() class LinearWarmup(Warmup): def __init__(self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1) -> None: super().__init__(lr_scheduler, warmup_duration, last_step) def get_warmup_factor(self, step): return min(1.0, (step + 1) / self.warmup_duration)