import copy

import torch.optim as optim
import torch.distributed as dist

from timm.scheduler.cosine_lr import CosineLRScheduler


def is_main_process():
    # Rank 0 is the main process (assumes the default process group is initialized).
    return dist.get_rank() == 0


def check_keywords_in_name(name, keywords=()):
    # Return True if any of the given keywords occurs in the parameter name.
    isin = False
    for keyword in keywords:
        if keyword in name:
            isin = True
    return isin
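
# Illustrative usage (not part of the original module): the keyword match is a
# plain substring test on the parameter name, e.g.
#     check_keywords_in_name("layers.0.attn.relative_position_bias_table",
#                            ("relative_position_bias",))   # -> True
#     check_keywords_in_name("layers.0.attn.qkv.weight", ())  # -> False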


def set_weight_decay(model, skip_list=(), skip_keywords=(), weight_decay=0.001, lr=2e-6, have=(), not_have=()):
    """Split trainable parameters into two groups: one with weight decay and one without.

    1-D tensors (e.g. norm weights), biases, and parameters matched by skip_list or
    skip_keywords are excluded from weight decay. The optional have / not_have keyword
    tuples restrict which parameters are included at all.
    """
    has_decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not optimized
        if len(have) > 0 and not check_keywords_in_name(name, have):
            continue  # keep only parameters whose name matches `have`
        if len(not_have) > 0 and check_keywords_in_name(name, not_have):
            continue  # drop parameters whose name matches `not_have`
        if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
        else:
            has_decay.append(param)

    return [{'params': has_decay, 'weight_decay': weight_decay, 'lr': lr},
            {'params': no_decay, 'weight_decay': 0., 'lr': lr}]
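
# Illustrative sketch (not part of the original module): the parameter groups
# produced by set_weight_decay can be handed directly to an optimizer, e.g.
#
#     param_groups = set_weight_decay(model,
#                                     skip_keywords=("relative_position_bias",),
#                                     weight_decay=0.05, lr=1e-4)
#     optimizer = optim.AdamW(param_groups, betas=(0.9, 0.98), eps=1e-8)
#
# The skip_keywords, weight_decay, and lr values above are placeholders, not
# settings taken from this repository's configs.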


def build_optimizer(config, model):
    # Unwrap DistributedDataParallel (or similar wrappers) before reading parameters.
    model = model.module if hasattr(model, 'module') else model

    optimizer = optim.AdamW(model.parameters(), lr=config.TRAIN.LR,
                            weight_decay=config.TRAIN.WEIGHT_DECAY,
                            betas=(0.9, 0.98), eps=1e-8)

    return optimizer


def build_scheduler(config, optimizer, n_iter_per_epoch):
    num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
    warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)

    # Cosine decay over the full training run, measured in iterations
    # (t_in_epochs=False), with a linear warmup from 0 up to the base LR.
    lr_scheduler = CosineLRScheduler(
        optimizer,
        t_initial=num_steps,
        lr_min=config.TRAIN.LR / 100,
        warmup_lr_init=0,
        warmup_t=warmup_steps,
        cycle_limit=1,
        t_in_epochs=False,
    )

    return lr_scheduler
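
# Illustrative wiring (a sketch, not code used elsewhere in this module):
# because the scheduler is built with t_in_epochs=False, it expects to be
# advanced once per iteration via step_update() with the global step count.
#
#     optimizer = build_optimizer(config, model)
#     lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch=len(data_loader))
#     for epoch in range(config.TRAIN.EPOCHS):
#         for idx, batch in enumerate(data_loader):
#             ...  # forward / backward / optimizer.step()
#             lr_scheduler.step_update(epoch * len(data_loader) + idx)
#
# `data_loader` and the loop body are placeholders; only build_optimizer,
# build_scheduler, and CosineLRScheduler.step_update refer to real APIs here.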