Spaces:
Sleeping
Sleeping
| from torch.optim.lr_scheduler import _LRScheduler | |
| from torch.optim.lr_scheduler import StepLR | |
| from torch.optim.lr_scheduler import MultiStepLR | |
| from torch.optim.lr_scheduler import ExponentialLR | |
| from torch.optim.lr_scheduler import CosineAnnealingLR | |
| from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau | |
class ConstantLR(_LRScheduler):
    """Learning-rate "scheduler" that never changes the learning rate.

    Every step reports each param group's initial learning rate (the
    ``base_lrs`` recorded by the base class), so the optimizer's rates stay
    fixed. This lets training code call ``scheduler.step()`` unconditionally
    even when no real schedule is configured.

    NOTE: this is a local class; it is unrelated to
    ``torch.optim.lr_scheduler.ConstantLR`` (which scales the rate by a
    factor for a number of iterations).

    Args:
        optimizer: The optimizer whose learning rates are held constant.
        last_epoch: Index of the last epoch, forwarded to the base class
            (-1 means training starts from scratch).
    """

    def __init__(self, optimizer, last_epoch=-1):
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # A fresh list each call, so callers cannot mutate base_lrs.
        return list(self.base_lrs)
# Registry of supported LR schedulers, keyed by the name a config uses to
# select one (the "name" entry consumed by get_scheduler below).
SCHEDULERS = dict(
    ConstantLR=ConstantLR,
    StepLR=StepLR,
    MultiStepLR=MultiStepLR,
    CosineAnnealingLR=CosineAnnealingLR,
    CosineAnnealingWarmRestarts=CosineAnnealingWarmRestarts,
    ExponentialLR=ExponentialLR,
    ReduceLROnPlateau=ReduceLROnPlateau,
)
def get_scheduler(optimizer, kwargs):
    """Build the learning-rate scheduler described by a config mapping.

    Args:
        optimizer: The optimizer the scheduler will wrap.
        kwargs: Mapping whose ``"name"`` entry selects a class from
            ``SCHEDULERS``; all remaining entries are passed to that class's
            constructor as keyword arguments. ``None`` selects
            ``ConstantLR`` (learning rate left unchanged). The caller's
            mapping is not modified.

    Returns:
        The constructed scheduler instance.

    Raises:
        KeyError: If ``"name"`` is missing, or names no entry of
            ``SCHEDULERS``.
    """
    if kwargs is None:
        print("No lr scheduler is used.")
        return ConstantLR(optimizer)
    # Shallow copy: popping "name" straight from the caller's dict would make
    # a second call with the same config fail with KeyError('name').
    kwargs = dict(kwargs)
    name = kwargs.pop("name")
    print("Using scheduler: '%s' with params: %s" % (name, kwargs))
    try:
        scheduler_cls = SCHEDULERS[name]
    except KeyError:
        # Same exception type as before, but list the valid choices.
        raise KeyError(
            "Unknown scheduler %r; expected one of %s"
            % (name, sorted(SCHEDULERS))
        ) from None
    return scheduler_cls(optimizer, **kwargs)