Spaces:
Sleeping
Sleeping
Evgeny Zhukov
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
2ba4412
| import math | |
| from torch.optim.lr_scheduler import _LRScheduler | |
__all__ = ['AnnealingLR']


class AnnealingLR(_LRScheduler):
    """Learning-rate schedule with linear warmup followed by an optional decay.

    While ``warmup_steps > 0`` and ``current_step <= warmup_steps`` the learning
    rate rises linearly from 0 to ``base_lr``. After that it decays from
    ``base_lr`` towards 0 over the remaining ``total_steps - warmup_steps``
    steps. The decay is linear (``'linear'``), a half cosine (``'cosine'``),
    or absent (``'none'``, constant ``base_lr``). Steps past ``total_steps``
    keep the fully decayed value. Every applied rate is floored at ``min_lr``,
    including during warmup.

    The same learning rate is written to every param group of ``optimizer``.
    ``optimizer`` may be a single optimizer or a list of optimizers.

    NOTE(review): ``_LRScheduler.__init__`` is never called here. Inherited
    helpers that depend on its state (e.g. ``get_last_lr``) are therefore
    presumably unusable; confirm before relying on them.
    """

    def __init__(self, optimizer, base_lr, warmup_steps, total_steps,
                 decay_mode='cosine', min_lr=0.0, last_step=-1):
        """Configure the schedule and immediately apply the rate for step ``last_step + 1``.

        Args:
            optimizer: object exposing ``param_groups`` (list of dicts with an
                ``'lr'`` key), or a list of such objects.
            base_lr: learning rate reached at the end of warmup.
            warmup_steps: length of the linear warmup; 0 disables warmup.
            total_steps: step at which the decay reaches its final value.
            decay_mode: one of ``'linear'``, ``'cosine'`` or ``'none'``.
            min_lr: lower bound applied to every learning rate set on the optimizer.
            last_step: index of the last completed step; the schedule starts
                at ``last_step + 1``.
        """
        assert decay_mode in ['linear', 'cosine', 'none']
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_mode = decay_mode
        self.min_lr = min_lr
        self.current_step = last_step + 1
        self.step(self.current_step)

    def get_lr(self):
        """Return the learning rate for ``self.current_step`` before the ``min_lr`` floor.

        The ``min_lr`` floor is applied by :meth:`step`, not here.
        """
        if self.warmup_steps > 0 and self.current_step <= self.warmup_steps:
            return self.base_lr * self.current_step / self.warmup_steps

        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps == 0:
            # Empty decay window (total_steps == warmup_steps): the decay is
            # complete as soon as warmup is over. Without this guard the ratio
            # below raises ZeroDivisionError.
            ratio = 1.0 if self.current_step > self.warmup_steps else 0.0
        else:
            ratio = (self.current_step - self.warmup_steps) / decay_steps
        # Clamp into [0, 1] so steps beyond total_steps hold the final value.
        ratio = min(1.0, max(0.0, ratio))

        if self.decay_mode == 'linear':
            return self.base_lr * (1 - ratio)
        if self.decay_mode == 'cosine':
            return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0
        return self.base_lr

    def step(self, current_step=None):
        """Move to ``current_step`` and write its learning rate to the optimizer(s).

        Args:
            current_step: step index to jump to; ``None`` advances by one.
        """
        if current_step is None:
            current_step = self.current_step + 1
        self.current_step = current_step
        new_lr = max(self.min_lr, self.get_lr())
        optimizers = self.optimizer if isinstance(self.optimizer, list) else [self.optimizer]
        for optimizer in optimizers:
            for group in optimizer.param_groups:
                group['lr'] = new_lr

    def state_dict(self):
        """Return the schedule configuration and current position as a plain dict."""
        return {
            'base_lr': self.base_lr,
            'warmup_steps': self.warmup_steps,
            'total_steps': self.total_steps,
            'decay_mode': self.decay_mode,
            'min_lr': self.min_lr,
            'current_step': self.current_step,
        }

    def load_state_dict(self, state_dict):
        """Restore the values produced by :meth:`state_dict`.

        ``'min_lr'`` is optional, so checkpoints saved before it was stored
        still load; in that case the current ``min_lr`` is kept. The
        optimizer's learning rate is not modified until the next :meth:`step`.
        """
        self.base_lr = state_dict['base_lr']
        self.warmup_steps = state_dict['warmup_steps']
        self.total_steps = state_dict['total_steps']
        self.decay_mode = state_dict['decay_mode']
        self.min_lr = state_dict.get('min_lr', self.min_lr)
        self.current_step = state_dict['current_step']