| | """ Scheduler Factory |
| | Hacked together by / Copyright 2020 Ross Wightman |
| | """ |
| | from .timm.cosine_lr import CosineLRScheduler |
| | from .timm.tanh_lr import TanhLRScheduler |
| | from .timm.step_lr import StepLRScheduler |
| | from .timm.plateau_lr import PlateauLRScheduler |
| | import torch |
| |
|
def _first_attr(args, *names):
    """Return ``getattr(args, name)`` for the first of *names* that *args* defines.

    Lets a branch accept more than one spelling of the same config option
    (e.g. ``min_lr`` / ``lr_min``). Raises ``AttributeError`` naming the first
    (preferred) candidate when none of them is present, i.e. the same error a
    plain ``args.<preferred>`` access would raise.
    """
    missing = object()
    for name in names:
        value = getattr(args, name, missing)
        if value is not missing:
            return value
    raise AttributeError(
        f"{type(args).__name__!r} object has no attribute {names[0]!r}")


def _noise_range(args, num_epochs):
    """Scale ``args.lr_noise`` by *num_epochs* to get the noise window(s).

    Returns ``None`` when ``lr_noise`` is missing or ``None``; a single value
    when it is a scalar or a one-element list/tuple; otherwise a list with each
    element multiplied by *num_epochs*. (Presumably ``lr_noise`` is given as
    fractions of total training — confirm against the config.)
    """
    lr_noise = getattr(args, 'lr_noise', None)
    if lr_noise is None:
        return None
    if isinstance(lr_noise, (list, tuple)):
        scaled = [n * num_epochs for n in lr_noise]
        return scaled[0] if len(scaled) == 1 else scaled
    return lr_noise * num_epochs


def _noise_kwargs(args, noise_range):
    """Build the LR-noise keyword arguments shared by the timm-style schedulers."""
    return dict(
        noise_range_t=noise_range,
        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
        noise_std=getattr(args, 'lr_noise_std', 1.),
        noise_seed=getattr(args, 'seed', 42),
    )


def create_scheduler(args, optimizer, **kwargs):
    """Build the learning-rate scheduler selected by ``args.lr_policy``.

    Args:
        args: config namespace. ``epochs`` and ``lr_policy`` are always read;
            the remaining attributes are read per policy (see each branch).
        optimizer: the optimizer whose learning rate is scheduled.
        **kwargs: extra options.
            ``total_steps`` is required by ``'onecyclelr'`` and
            ``'cosinerestart'`` (``KeyError`` if missing).
            ``init_epoch`` (default 0) is subtracted from
            ``args.decay_epochs`` by ``'step'``.

    Returns:
        The scheduler instance, or ``None`` when ``args.lr_policy`` is not one
        of ``'cosine'``, ``'tanh'``, ``'step'``, ``'plateau'``,
        ``'onecyclelr'``, ``'cosinerestart'``.
    """
    num_epochs = args.epochs
    noise_range = _noise_range(args, num_epochs)
    policy = args.lr_policy

    if policy == 'cosine':
        return CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            # Accept both spellings: this branch historically read ``lr_min``
            # while the other branches read ``min_lr``.
            lr_min=_first_attr(args, 'lr_min', 'min_lr'),
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **_noise_kwargs(args, noise_range),
        )

    if policy == 'tanh':
        return TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=_first_attr(args, 'min_lr', 'lr_min'),
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **_noise_kwargs(args, noise_range),
        )

    if policy == 'step':
        return StepLRScheduler(
            optimizer,
            # ``kwargs`` is a dict, so it must be read with ``.get`` (a
            # ``getattr`` lookup on it would always fall back to 0).
            decay_t=args.decay_epochs - kwargs.get('init_epoch', 0),
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **_noise_kwargs(args, noise_range),
        )

    if policy == 'plateau':
        # Metrics whose name contains "loss" are minimised; anything else is
        # maximised. ``or ''`` guards against an explicit ``eval_metric=None``.
        eval_metric = getattr(args, 'eval_metric', '') or ''
        mode = 'min' if 'loss' in eval_metric else 'max'
        return PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=_first_attr(args, 'min_lr', 'lr_min'),
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **_noise_kwargs(args, noise_range),
        )

    if policy == "onecyclelr":
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.LR,
            total_steps=kwargs["total_steps"],
            pct_start=args.PCT_START,
            div_factor=args.DIV_FACTOR_ONECOS,
            # NOTE(review): ``FIN_DACTOR_ONCCOS`` looks like a misspelling, but
            # it must match the config key; confirm before renaming.
            final_div_factor=args.FIN_DACTOR_ONCCOS,
        )

    if policy == "cosinerestart":
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=kwargs["total_steps"],
            T_mult=2,
            eta_min=1e-6,
            last_epoch=-1,
        )

    # Unknown policy: callers receive ``None`` (kept for compatibility).
    return None