| import inspect |
| import torch.nn as nn |
| from typing import List, Optional, Union |
| from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper |
| from mmengine.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, |
| OPTIMIZERS) |
|
|
|
|
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Partition a model's trainable parameters into two optimizer groups.

    Parameters that are 1-D (biases, norm scales), end with ``.bias``,
    appear in ``skip_list``, or contain ``'diffloss'`` in their name get
    zero weight decay; everything else receives ``weight_decay``.

    Args:
        model (nn.Module): Model whose named parameters are grouped.
        weight_decay (float): Decay applied to the decayed group.
        skip_list (tuple): Extra parameter names to exempt from decay.

    Returns:
        list[dict]: Two param-group dicts, no-decay group first.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        # Frozen parameters never enter the optimizer.
        if not param.requires_grad:
            continue
        exempt = (
            len(param.shape) == 1
            or name.endswith(".bias")
            or name in skip_list
            or 'diffloss' in name
        )
        (no_decay if exempt else decay).append(param)

    num_decay_params = sum(p.numel() for p in decay)
    num_nodecay_params = sum(p.numel() for p in no_decay)
    print(f"num decayed parameter tensors: {len(decay)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(no_decay)}, with {num_nodecay_params:,} parameters")

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]
|
|
|
|
class MAROptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Optimizer wrapper constructor using MAR-style weight-decay grouping.

    Instead of the parent's ``paramwise_cfg`` mechanism, parameters are
    split via :func:`add_weight_decay` (1-D tensors, ``.bias`` and
    ``diffloss`` parameters get zero decay) and the resulting param
    groups are passed as the optimizer's first positional argument.

    NOTE(review): any ``paramwise_cfg`` supplied to this constructor is
    silently ignored — confirm this is intended by callers.
    """

    def __call__(self, model: nn.Module) -> OptimWrapper:
        # Unwrap DDP/DataParallel-style containers.
        if hasattr(model, 'module'):
            model = model.module

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()
        optimizer_cls = optimizer_cfg['type']

        # Resolve a string type to the actual optimizer class so its
        # signature can be inspected.
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(optimizer_cls)
        # First constructor argument is the params argument
        # (e.g. ``params`` for torch optimizers).
        first_arg_name = next(
            iter(inspect.signature(optimizer_cls).parameters))

        # ``weight_decay`` is popped so the optimizer receives decay only
        # through the per-group settings; missing key defaults to 0.
        param_groups = add_weight_decay(
            model, optimizer_cfg.pop('weight_decay', 0))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
|
|