| import torch |
| import torch.optim as optim |
| from catalyst.contrib.nn import Lookahead |
| from catalyst import utils |
| import math |
|
|
class lambdax:
    """Holder for a polynomial ("poly") learning-rate decay function.

    NOTE(review): the original declared ``lambda_epoch`` as a
    ``@staticmethod`` while still taking ``self`` and reading ``self.cfg``;
    calling ``instance.lambda_epoch(epoch)`` bound ``epoch`` to ``self`` and
    raised ``TypeError``. Fixed by making it a plain instance method.
    """

    def __init__(self, cfg):
        # cfg must provide ``max_epoch`` and ``poly_exp``.
        self.cfg = cfg

    def lambda_epoch(self, epoch):
        """Return the decay factor ``(1 - epoch / max_epoch) ** poly_exp``."""
        return math.pow(1 - epoch / self.cfg.max_epoch, self.cfg.poly_exp)
|
|
|
|
def get_optimizer(cfg, net):
    """Construct a torch optimizer for ``net`` as described by ``cfg``.

    Args:
        cfg: optimizer config. Reads ``lr_mode``, ``type``, ``lr``,
            ``weight_decay`` and, for SGD, ``momentum``; for
            ``lr_mode == 'multi'`` also ``backbone_lr`` and
            ``backbone_weight_decay``.
        net: the model whose parameters will be optimized.

    Returns:
        A configured ``torch.optim.AdamW`` or ``torch.optim.SGD`` instance.

    Raises:
        KeyError: if ``cfg.type`` is not a supported optimizer name.
    """
    if cfg.lr_mode == 'multi':
        # Give the backbone its own lr / weight decay via catalyst's
        # layerwise parameter grouping.
        per_layer = {"backbone.*": dict(lr=cfg.backbone_lr,
                                        weight_decay=cfg.backbone_weight_decay)}
        params = utils.process_model_params(net, layerwise_params=per_layer)
    else:
        params = net.parameters()

    if cfg.type == "AdamW":
        return optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

    if cfg.type == "SGD":
        return optim.SGD(params, lr=cfg.lr, weight_decay=cfg.weight_decay,
                         momentum=cfg.momentum, nesterov=False)

    raise KeyError("The optimizer type ( %s ) doesn't exist!!!" % cfg.type)
|
|
def get_scheduler(cfg, optimizer):
    """Create the learning-rate scheduler named by ``cfg.type``.

    Args:
        cfg: scheduler config; which extra fields are read depends on
            ``cfg.type`` (``max_epoch``/``poly_exp``, ``step_size``/``gamma``,
            ``milestones``, ``patience``/``factor``).
        optimizer: the optimizer the scheduler will drive.

    Returns:
        A ``torch.optim.lr_scheduler`` instance.

    Raises:
        KeyError: if ``cfg.type`` is not a supported scheduler name.
    """
    kind = cfg.type

    if kind == 'Poly':
        def poly_decay(epoch):
            # Polynomial decay toward zero over cfg.max_epoch epochs.
            return math.pow(1 - epoch / cfg.max_epoch, cfg.poly_exp)
        return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=poly_decay)

    if kind == 'CosineAnnealingLR':
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cfg.max_epoch, eta_min=1e-6)

    if kind == 'linear':
        # Linear ramp from 1.0 down to ~0 across the training run.
        return optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: 1.0 - epoch / float(cfg.max_epoch + 1))

    if kind == 'step':
        return optim.lr_scheduler.StepLR(
            optimizer, step_size=cfg.step_size, gamma=cfg.gamma)

    if kind == 'multistep':
        return optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=cfg.milestones, gamma=cfg.gamma)

    if kind == 'reduce':
        # 'max' mode: metric being monitored is expected to increase.
        return optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'max', patience=cfg.patience, factor=cfg.factor)

    raise KeyError("The scheduler type ( %s ) doesn't exist!!!" % cfg.type)
|
|
def build_optimizer(cfg, net):
    """Build the (optimizer, scheduler) pair for a training run.

    ``cfg.optimizer`` is handed to :func:`get_optimizer`; ``cfg.scheduler``
    is handed to :func:`get_scheduler`, which wraps the freshly built
    optimizer.

    Returns:
        Tuple of ``(optimizer, scheduler)``.
    """
    optimizer = get_optimizer(cfg.optimizer, net)
    scheduler = get_scheduler(cfg.scheduler, optimizer)
    return optimizer, scheduler
|
|
|
|
| |
|
|
| |
| |
|
|