Spaces:
Sleeping
Sleeping
| import copy | |
| import warnings | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch_geometric.data import Data, Batch | |
| from .warmup import GradualWarmupScheduler | |
| #customize exp lr scheduler with min lr | |
| class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR): | |
| def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False): | |
| self.gamma = gamma | |
| self.min_lr = min_lr | |
| super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch, verbose) | |
| def get_lr(self): | |
| if not self._get_lr_called_within_step: | |
| warnings.warn("To get the last learning rate computed by the scheduler, " | |
| "please use `get_last_lr()`.", UserWarning) | |
| if self.last_epoch == 0: | |
| return self.base_lrs | |
| return [max(group['lr'] * self.gamma, self.min_lr) | |
| for group in self.optimizer.param_groups] | |
| def _get_closed_form_lr(self): | |
| return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr) | |
| for base_lr in self.base_lrs] | |
| def repeat_data(data: Data, num_repeat) -> Batch: | |
| datas = [copy.deepcopy(data) for i in range(num_repeat)] | |
| return Batch.from_data_list(datas) | |
| def repeat_batch(batch: Batch, num_repeat) -> Batch: | |
| datas = batch.to_data_list() | |
| new_data = [] | |
| for i in range(num_repeat): | |
| new_data += copy.deepcopy(datas) | |
| return Batch.from_data_list(new_data) | |
| def inf_iterator(iterable): | |
| iterator = iterable.__iter__() | |
| while True: | |
| try: | |
| yield iterator.__next__() | |
| except StopIteration: | |
| iterator = iterable.__iter__() | |
| def get_optimizer(cfg, model): | |
| if cfg.type == 'adam': | |
| return torch.optim.Adam( | |
| model.parameters(), | |
| lr=cfg.lr, | |
| weight_decay=cfg.weight_decay, | |
| betas=(cfg.beta1, cfg.beta2, ) | |
| ) | |
| else: | |
| raise NotImplementedError('Optimizer not supported: %s' % cfg.type) | |
| def get_scheduler(cfg, optimizer): | |
| if cfg.type == 'plateau': | |
| return torch.optim.lr_scheduler.ReduceLROnPlateau( | |
| optimizer, | |
| factor=cfg.factor, | |
| patience=cfg.patience, | |
| min_lr=cfg.min_lr | |
| ) | |
| elif cfg.type == 'warmup_plateau': | |
| return GradualWarmupScheduler( | |
| optimizer, | |
| multiplier = cfg.multiplier, | |
| total_epoch = cfg.total_epoch, | |
| after_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( | |
| optimizer, | |
| factor=cfg.factor, | |
| patience=cfg.patience, | |
| min_lr=cfg.min_lr | |
| ) | |
| ) | |
| elif cfg.type == 'expmin': | |
| return ExponentialLR_with_minLr( | |
| optimizer, | |
| gamma=cfg.factor, | |
| min_lr=cfg.min_lr, | |
| ) | |
| elif cfg.type == 'expmin_milestone': | |
| gamma = np.exp(np.log(cfg.factor) / cfg.milestone) | |
| return ExponentialLR_with_minLr( | |
| optimizer, | |
| gamma=gamma, | |
| min_lr=cfg.min_lr, | |
| ) | |
| else: | |
| raise NotImplementedError('Scheduler not supported: %s' % cfg.type) | |