import math

import torch.optim as optim
from catalyst import utils
from catalyst.contrib.nn import Lookahead


class lambdax:
    """Picklable callable for LambdaLR: unlike an inline lambda, a bound
    method of this class can be pickled (e.g. when checkpointing the whole
    scheduler object)."""

    def __init__(self, cfg):
        self.cfg = cfg

    def lambda_epoch(self, epoch):
        # Polynomial decay: (1 - epoch / max_epoch) ** poly_exp
        return math.pow(1 - epoch / self.cfg.max_epoch, self.cfg.poly_exp)
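# Usage sketch (illustrative, not part of the original file): lambdax mirrors
# the inline 'Poly' lambda in get_scheduler below, but survives pickling:
#
#   poly = lambdax(cfg.scheduler)
#   scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=poly.lambda_epoch)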


def get_optimizer(cfg, net):
    if cfg.lr_mode == 'multi':
        # Layer-wise settings: give the backbone its own LR and weight decay.
        layerwise_params = {"backbone.*": dict(lr=cfg.backbone_lr, weight_decay=cfg.backbone_weight_decay)}
        net_params = utils.process_model_params(net, layerwise_params=layerwise_params)
    else:
        net_params = net.parameters()

    if cfg.type == "AdamW":
        optimizer = optim.AdamW(net_params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    elif cfg.type == "SGD":
        optimizer = optim.SGD(net_params, lr=cfg.lr, weight_decay=cfg.weight_decay,
                              momentum=cfg.momentum, nesterov=False)
    else:
        raise KeyError("The optimizer type '%s' does not exist!" % cfg.type)

    return optimizer
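# Illustrative config for the 'multi' branch above (field names follow
# get_optimizer; the concrete values are assumptions):
#
#   from types import SimpleNamespace
#   opt_cfg = SimpleNamespace(lr_mode='multi', type='AdamW', lr=1e-3,
#                             weight_decay=1e-2, backbone_lr=1e-4,
#                             backbone_weight_decay=1e-2)
#   optimizer = get_optimizer(opt_cfg, net)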


def get_scheduler(cfg, optimizer):
    if cfg.type == 'Poly':
        # Polynomial decay over cfg.max_epoch epochs.
        lambda1 = lambda epoch: math.pow(1 - epoch / cfg.max_epoch, cfg.poly_exp)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    elif cfg.type == 'CosineAnnealingLR':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.max_epoch, eta_min=1e-6)
    elif cfg.type == 'linear':
        def lambda_rule(epoch):
            # Linear decay from 1.0 towards 0 over cfg.max_epoch epochs.
            return 1.0 - epoch / float(cfg.max_epoch + 1)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif cfg.type == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=cfg.step_size, gamma=cfg.gamma)
    elif cfg.type == 'multistep':
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.milestones, gamma=cfg.gamma)
    elif cfg.type == 'reduce':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=cfg.patience, factor=cfg.factor)
    else:
        raise KeyError("The scheduler type '%s' does not exist!" % cfg.type)

    return scheduler
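# Note (not in the original): unlike the epoch-based schedulers above,
# ReduceLROnPlateau must be stepped with the monitored value, e.g.
#
#   scheduler.step(val_score)  # 'max' mode: reduce LR when val_score plateaus
#
# whereas the others are stepped once per epoch with scheduler.step().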


def build_optimizer(cfg, net):
    optimizer = get_optimizer(cfg.optimizer, net)
    scheduler = get_scheduler(cfg.scheduler, optimizer)
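    # Sketch (assumption, not in the original): Lookahead is imported above
    # but never used; a hypothetical cfg.use_lookahead flag could wrap the
    # base optimizer (catalyst's Lookahead defaults: k=5, alpha=0.5).
    if getattr(cfg, "use_lookahead", False):
        optimizer = Lookahead(optimizer)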
    return optimizer, scheduler
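

if __name__ == "__main__":
    # Minimal smoke test (illustrative; config field names follow the
    # functions above, the concrete values here are assumptions).
    from types import SimpleNamespace

    import torch.nn as nn

    net = nn.Linear(8, 4)
    cfg = SimpleNamespace(
        optimizer=SimpleNamespace(lr_mode='single', type='SGD', lr=0.01,
                                  weight_decay=1e-4, momentum=0.9),
        scheduler=SimpleNamespace(type='Poly', max_epoch=100, poly_exp=0.9),
    )
    optimizer, scheduler = build_optimizer(cfg, net)
    for epoch in range(3):
        optimizer.step()   # normally preceded by loss.backward()
        scheduler.step()   # epoch-wise LR update
        print(epoch, optimizer.param_groups[0]['lr'])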