from torch import optim


def build_optimizer(config, model):
| """ |
| Build optimizer, set weight decay of normalization to 0 by default. |
| """ |
    skip = {}
    skip_keywords = {}
    if hasattr(model, "no_weight_decay"):
        skip = model.no_weight_decay()
    if hasattr(model, "no_weight_decay_keywords"):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords)

    opt_lower = config.optimizer.lower()
    if opt_lower == "sgd":
        optimizer = optim.SGD(
            parameters,
            momentum=config.momentum,
            nesterov=True,
            lr=config.lr,
            weight_decay=config.weight_decay,
        )
    elif opt_lower == "adamw":
        optimizer = optim.AdamW(
            parameters,
            eps=config.eps,
            betas=config.betas,
            lr=config.lr,
            weight_decay=config.weight_decay,
        )
    else:
        # Fail loudly instead of silently returning None for unknown names.
        raise ValueError(f"Unsupported optimizer: {config.optimizer}")
    return optimizer
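
# Example usage (a minimal sketch: the attribute names below are exactly the
# ones build_optimizer reads, but the values and the SimpleNamespace stand-in
# for a real config object are illustrative assumptions):
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(
#       optimizer="adamw",
#       lr=5e-4,
#       weight_decay=0.05,
#       eps=1e-8,
#       betas=(0.9, 0.999),
#       momentum=0.9,  # read only by the SGD branch
#   )
#   optimizer = build_optimizer(config, model)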


def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """
    Partition trainable parameters into a weight-decay group and a no-decay
    group (1-D tensors, biases, and explicitly skipped names get no decay).
    """
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        # Frozen parameters take no optimizer step, so skip them entirely.
        if not param.requires_grad:
            continue
        # 1-D tensors (norm scales/shifts), biases, and explicitly skipped
        # names are exempt from weight decay.
        if (
            len(param.shape) == 1
            or name.endswith(".bias")
            or (name in skip_list)
            or check_keywords_in_name(name, skip_keywords)
        ):
            no_decay.append(param)
        else:
            has_decay.append(param)
    return [{"params": has_decay}, {"params": no_decay, "weight_decay": 0.0}]
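
# The two groups returned above plug straight into a torch.optim constructor:
# the first inherits the optimizer's default weight_decay, while the second
# explicitly overrides it with 0.0. Illustrative call (values are assumptions):
#
#   optim.AdamW(set_weight_decay(model), lr=5e-4, weight_decay=0.05)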


def check_keywords_in_name(name, keywords=()):
    """Return True if any of the given keywords occurs in `name`."""
    return any(keyword in name for keyword in keywords)
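

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original module): build a tiny
    # model containing both 2-D and 1-D parameters and verify that biases and
    # LayerNorm parameters land in the zero-weight-decay group. The
    # SimpleNamespace below is an assumed stand-in for a real config object.
    from types import SimpleNamespace

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    config = SimpleNamespace(
        optimizer="adamw",
        lr=5e-4,
        weight_decay=0.05,
        eps=1e-8,
        betas=(0.9, 0.999),
        momentum=0.9,  # read only by the SGD branch
    )
    optimizer = build_optimizer(config, model)

    decay_group, no_decay_group = optimizer.param_groups
    # Linear.weight (2-D) decays; Linear.bias, LayerNorm.weight, and
    # LayerNorm.bias (all 1-D) do not.
    assert len(decay_group["params"]) == 1
    assert len(no_decay_group["params"]) == 3
    assert no_decay_group["weight_decay"] == 0.0
    print("decay params:", len(decay_group["params"]),
          "/ no-decay params:", len(no_decay_group["params"]))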