from enum import Enum
import itertools
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union

import torch

from detectron2.config import CfgNode
from detectron2.solver.build import maybe_add_gradient_clipping
def match_name_keywords(n, name_keywords):
    """Return True if any keyword in `name_keywords` occurs in the parameter name `n`."""
    out = False
    for b in name_keywords:
        if b in n:
            out = True
            break
    return out
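# Illustrative example (hypothetical parameter names, not taken from this file):
# with cfg.SOLVER.CUSTOM_MULTIPLIER_NAME = ["cls_score", "bbox_pred"],
# match_name_keywords("roi_heads.box_predictor.cls_score.weight", ["cls_score", "bbox_pred"])
# returns True, while match_name_keywords("backbone.res2.conv1.weight", ["cls_score", "bbox_pred"])
# returns False.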


def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config, applying per-parameter learning-rate
    multipliers to backbone parameters and to parameters whose names match
    the keywords in cfg.SOLVER.CUSTOM_MULTIPLIER_NAME.
    """
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    custom_multiplier_name = cfg.SOLVER.CUSTOM_MULTIPLIER_NAME
    optimizer_type = cfg.SOLVER.OPTIMIZER
    for key, value in model.named_parameters(recurse=True):
        if not value.requires_grad:
            continue
        # Skip parameters that are registered under more than one name.
        if value in memo:
            continue
        memo.add(value)
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "backbone" in key:
            lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
        if match_name_keywords(key, custom_multiplier_name):
            lr = lr * cfg.SOLVER.CUSTOM_MULTIPLIER
            print('Custom LR', key, lr)
        param = {"params": [value], "lr": lr}
        # AdamW receives a single optimizer-level weight decay below; other
        # optimizers carry it per parameter group.
        if optimizer_type != 'ADAMW':
            param['weight_decay'] = weight_decay
        params += [param]
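    # Illustrative shape of one entry in `params` (values are hypothetical and
    # depend on the config, not taken from this file):
    #   {"params": [<torch.nn.Parameter>], "lr": 1e-4, "weight_decay": 1e-4}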

    def maybe_add_full_model_gradient_clipping(optim):
        # detectron2 has no built-in "full model" gradient clipping, so when it is
        # requested we wrap the optimizer class in a subclass that clips the global
        # gradient norm over all parameter groups before each step.
        clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                super().step(closure=closure)

        return FullModelGradientClippingOptimizer if enable else optim

    if optimizer_type == 'SGD':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV
        )
    elif optimizer_type == 'ADAMW':
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")
    # For clip types other than "full_model", defer to detectron2's standard
    # per-parameter-group gradient clipping.
    if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer
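

# --- Illustrative usage sketch (an assumption, not part of the original file) ---
# One plausible way to wire this builder into training is to override the
# build_optimizer hook of detectron2's DefaultTrainer. The trainer class name is
# hypothetical; the config is assumed to define the extra SOLVER keys read above
# (OPTIMIZER, BACKBONE_MULTIPLIER, CUSTOM_MULTIPLIER, CUSTOM_MULTIPLIER_NAME).
from detectron2.engine import DefaultTrainer


class CustomOptimizerTrainer(DefaultTrainer):
    """Hypothetical trainer that builds its optimizer with build_custom_optimizer."""

    @classmethod
    def build_optimizer(cls, cfg, model):
        return build_custom_optimizer(cfg, model)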