| """ |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
try:
    from torch.optim.lr_scheduler import LRScheduler  # public name in torch >= 2.0
except ImportError:  # older torch only exposes the private name
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|
| from ..core import register |
|
|
|
|
class Warmup:
    """Ramps each param group's lr up to its configured value over the first
    ``warmup_duration`` steps. Subclasses define the ramp shape by overriding
    :meth:`get_warmup_factor`.
    """
    def __init__(self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int=-1) -> None:
        self.lr_scheduler = lr_scheduler
        # Snapshot the target lrs; this assumes the wrapper is constructed
        # before any optimizer/scheduler steps have modified them.
        self.warmup_end_values = [pg['lr'] for pg in lr_scheduler.optimizer.param_groups]
        self.last_step = last_step
        self.warmup_duration = warmup_duration
        self.step()  # apply the warmup factor for the first step immediately
|
|
    def state_dict(self):
        # The wrapped scheduler keeps its own state_dict, so exclude it here.
        return {k: v for k, v in self.__dict__.items() if k != 'lr_scheduler'}
|
|
    def load_state_dict(self, state_dict):
        # Restore the attributes saved by state_dict().
        self.__dict__.update(state_dict)
|
|
    def get_warmup_factor(self, step, **kwargs):
        """Return the lr multiplier in [0, 1] for the given warmup step."""
        raise NotImplementedError
|
|
    def step(self):
        self.last_step += 1
        if self.last_step >= self.warmup_duration:
            return  # warmup finished; the wrapped scheduler owns the lrs now
        factor = self.get_warmup_factor(self.last_step)
        for i, pg in enumerate(self.lr_scheduler.optimizer.param_groups):
            pg['lr'] = factor * self.warmup_end_values[i]
|
|
    def finished(self):
        return self.last_step >= self.warmup_duration
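

# Checkpointing sketch (illustrative; the surrounding trainer and the ``ckpt``
# dict are assumptions, not part of this module). The warmup state can be saved
# and restored alongside the wrapped scheduler's own state, e.g.
#
#   ckpt = {'lr_scheduler': scheduler.state_dict(), 'warmup': warmup.state_dict()}
#   ...
#   warmup.load_state_dict(ckpt['warmup'])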
|
|
|
|
| @register() |
class LinearWarmup(Warmup):
    """Linear ramp: the factor grows from 1/warmup_duration at step 0 to 1.0
    at step warmup_duration - 1."""
    def __init__(self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1) -> None:
        super().__init__(lr_scheduler, warmup_duration, last_step)


    def get_warmup_factor(self, step):
        # (step + 1) / warmup_duration, clamped to 1.0 once warmup is over
        return min(1.0, (step + 1) / self.warmup_duration)
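

# Usage sketch (illustrative; the optimizer/scheduler choices and ``num_iters``
# are assumptions, not part of this module). Call ``warmup.step()`` once per
# training iteration until it reports finished():
#
#   model = torch.nn.Linear(10, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1000])
#   warmup = LinearWarmup(scheduler, warmup_duration=500)
#   for _ in range(num_iters):
#       ...                      # forward / backward
#       optimizer.step()
#       if not warmup.finished():
#           warmup.step()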
|
|