Spaces:
Build error
Build error
| from torch import optim as optim | |
def build_optimizer(config, model):
    """
    Build optimizer, set weight decay of normalization to 0 by default.

    Trainable parameters are split by ``set_weight_decay`` into parameter
    groups; the model may optionally expose ``no_weight_decay()`` (exact
    parameter names) and ``no_weight_decay_keywords()`` (name substrings) to
    exempt additional parameters from weight decay.

    Args:
        config: configuration object. Reads ``TRAIN.BASE_LR``,
            ``TRAIN.WEIGHT_DECAY`` and ``TRAIN.OPTIMIZER.NAME``, plus
            ``TRAIN.OPTIMIZER.MOMENTUM`` for SGD or ``TRAIN.OPTIMIZER.EPS`` /
            ``TRAIN.OPTIMIZER.BETAS`` for AdamW.
        model: module whose parameters are optimized.

    Returns:
        A ``torch.optim.SGD`` (with Nesterov momentum) or
        ``torch.optim.AdamW`` instance.

    Raises:
        ValueError: if ``TRAIN.OPTIMIZER.NAME`` is neither ``'sgd'`` nor
            ``'adamw'`` (compared case-insensitively). Previously this case
            silently returned ``None``.
    """
    # Optional model hooks; default to "nothing extra to exempt".
    skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()
    skip_keywords = (model.no_weight_decay_keywords()
                     if hasattr(model, 'no_weight_decay_keywords') else set())
    parameters = set_weight_decay(model, skip, skip_keywords, config.TRAIN.BASE_LR)

    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
    # The top-level weight_decay applies to every group that does not set its
    # own 'weight_decay' key.
    if opt_lower == 'sgd':
        return optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                         lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    if opt_lower == 'adamw':
        return optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                           lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    # Fail loudly on a config typo instead of returning None to the caller.
    raise ValueError(
        f"Unsupported optimizer {config.TRAIN.OPTIMIZER.NAME!r}; expected 'sgd' or 'adamw'"
    )
| # def set_weight_decay(model, skip_list=(), skip_keywords=(),lr=0.0): | |
| # has_decay = [] | |
| # no_decay = [] | |
| # high_lr = [] | |
| # for name, param in model.named_parameters(): | |
| # if not param.requires_grad: | |
| # continue # frozen weights | |
| # if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ | |
| # check_keywords_in_name(name, skip_keywords): | |
| # if 'meta' in name: | |
| # high_lr.append(param) | |
| # else: | |
| # no_decay.append(param) | |
| # # print(f"{name} has no weight decay") | |
| # else: | |
| # has_decay.append(param) | |
| # return [{'params': has_decay}, | |
| # # {'params':high_lr,'weight_decay': 0.,'lr':lr*10}, | |
| # {'params':high_lr,'lr':lr*20}, | |
| # {'params': no_decay, 'weight_decay': 0.}] | |
def set_weight_decay(model, skip_list=(), skip_keywords=(), lr=0.0):
    """Partition a model's trainable parameters into two optimizer groups.

    A parameter is exempt from weight decay when it is 1-D, its name ends in
    ``.bias``, its name is in ``skip_list``, or its name contains any entry of
    ``skip_keywords``. Parameters with ``requires_grad`` False are omitted.

    Args:
        model: module providing ``named_parameters()``.
        skip_list: exact parameter names to exempt.
        skip_keywords: name substrings to exempt.
        lr: unused by this function; accepted for call-site compatibility.

    Returns:
        ``[{'params': decayed}, {'params': exempt, 'weight_decay': 0.}]``.
    """
    decayed, exempt = [], []
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            # Frozen weights belong to neither group.
            continue
        no_decay = (
            len(tensor.shape) == 1
            or param_name.endswith(".bias")
            or param_name in skip_list
            or check_keywords_in_name(param_name, skip_keywords)
        )
        if no_decay:
            exempt.append(tensor)
        else:
            decayed.append(tensor)
    return [
        {'params': decayed},
        {'params': exempt, 'weight_decay': 0.},
    ]
def check_keywords_in_name(name, keywords=()):
    """Return True if any entry of ``keywords`` occurs as a substring of ``name``.

    Args:
        name: parameter name to test.
        keywords: iterable of substrings to look for; ``None`` is treated as
            an empty collection.

    Returns:
        ``True`` if at least one keyword is contained in ``name``, else ``False``.
    """
    if keywords is None:
        return False
    # any() short-circuits on the first match instead of scanning every keyword.
    return any(keyword in name for keyword in keywords)