EAR_challenge / utils /optimizer.py
srijandas07's picture
Upload 52 files
1c990f3 verified
import copy
import torch.optim as optim
from timm.scheduler.cosine_lr import CosineLRScheduler
import torch.distributed as dist
def is_main_process():
    """Return True when this process should act as the main (rank 0) process.

    When ``torch.distributed`` is unavailable or no default process group has
    been initialised (e.g. a plain single-process run), there is only one
    process, so it is treated as the main one instead of letting
    ``dist.get_rank()`` raise.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0
def check_keywords_in_name(name, keywords=()):
    """Return True if any entry of ``keywords`` occurs as a substring of ``name``.

    Args:
        name: Parameter name to inspect.
        keywords: Iterable of substrings to look for. Empty by default, in
            which case the result is always False.

    Returns:
        bool: True as soon as one keyword is found in ``name``, else False.
    """
    # any() stops at the first match instead of scanning every keyword.
    return any(keyword in name for keyword in keywords)
def set_weight_decay(model, skip_list=(), skip_keywords=(), weight_decay=0.001, lr=2e-6, have=(), not_have=()):
    """Split a model's trainable parameters into decay / no-decay groups.

    Parameters are taken from ``model.named_parameters()`` and filtered:

    * parameters with ``requires_grad == False`` are ignored;
    * if ``have`` is non-empty, only names containing one of its keywords
      are kept;
    * if ``not_have`` is non-empty, names containing one of its keywords
      are dropped.

    A remaining parameter gets no weight decay when its shape is 1-D, its
    name ends with ``".bias"``, its name is listed in ``skip_list``, or its
    name contains one of ``skip_keywords``. Everything else is decayed.

    Returns:
        list[dict]: Two parameter-group dicts with keys ``'params'``,
        ``'weight_decay'`` and ``'lr'``: first the decayed group (using
        ``weight_decay``), then the non-decayed group (``weight_decay`` 0).
        Both groups use the same ``lr``.
    """
    decayed, undecayed = [], []

    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            # Frozen weights are left out of both groups.
            continue
        if len(have) > 0 and not check_keywords_in_name(param_name, have):
            continue
        if len(not_have) > 0 and check_keywords_in_name(param_name, not_have):
            continue

        exempt_from_decay = (
            len(tensor.shape) == 1
            or param_name.endswith(".bias")
            or param_name in skip_list
            or check_keywords_in_name(param_name, skip_keywords)
        )
        target = undecayed if exempt_from_decay else decayed
        target.append(tensor)

    decay_group = {'params': decayed, 'weight_decay': weight_decay, 'lr': lr}
    plain_group = {'params': undecayed, 'weight_decay': 0., 'lr': lr}
    return [decay_group, plain_group]
def build_optimizer(config, model):
    """Create an AdamW optimizer over all parameters of ``model``.

    If ``model`` exposes a ``module`` attribute (as wrapper objects such as
    data-parallel containers do), the wrapped module's parameters are used.

    Args:
        config: Object providing ``TRAIN.LR`` and ``TRAIN.WEIGHT_DECAY``.
        model: The model, optionally wrapped in an object with ``.module``.

    Returns:
        torch.optim.AdamW configured with betas (0.9, 0.98) and eps 1e-8.
    """
    # Unwrap the container, if any, so the bare module is optimised.
    bare_model = getattr(model, 'module', model)

    learning_rate = config.TRAIN.LR
    decay = config.TRAIN.WEIGHT_DECAY
    return optim.AdamW(
        bare_model.parameters(),
        lr=learning_rate,
        weight_decay=decay,
        betas=(0.9, 0.98),
        eps=1e-8,
    )
def build_scheduler(config, optimizer, n_iter_per_epoch):
    """Create a single-cycle cosine LR scheduler with warmup, counted in iterations.

    The schedule length and warmup length are converted from epochs to
    iterations by multiplying with ``n_iter_per_epoch`` and truncating to int.

    Args:
        config: Object providing ``TRAIN.EPOCHS``, ``TRAIN.WARMUP_EPOCHS``
            and ``TRAIN.LR``.
        optimizer: Optimizer whose learning rate is scheduled.
        n_iter_per_epoch: Number of iterations in one epoch.

    Returns:
        CosineLRScheduler with ``lr_min`` set to ``TRAIN.LR / 100``, warmup
        starting at 0, and one cycle.
    """
    total_iters = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
    warmup_iters = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
    floor_lr = config.TRAIN.LR / 100

    # t_in_epochs=False because both lengths above are in iterations.
    return CosineLRScheduler(
        optimizer,
        t_initial=total_iters,
        lr_min=floor_lr,
        warmup_lr_init=0,
        warmup_t=warmup_iters,
        cycle_limit=1,
        t_in_epochs=False,
    )