import copy
import inspect
from typing import List, Union

import torch
import torch.nn as nn

from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available, is_npu_support_full_precision
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS

from .optimizer_wrapper import OptimWrapper


def register_torch_optimizers() -> List[str]:
    """Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module(module=_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers


TORCH_OPTIMIZERS = register_torch_optimizers()
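
# Any optimizer collected above can now be built from a plain config dict.
# A minimal sketch (``nn.Linear`` stands in for any model; the
# hyperparameters are illustrative only):
#
#     model = nn.Linear(2, 2)
#     optimizer = OPTIMIZERS.build(
#         dict(type='SGD', lr=0.01, momentum=0.9,
#              params=model.parameters()))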
def register_torch_npu_optimizers() -> List[str]:
    """Register optimizers in ``torch_npu`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    if not is_npu_available():
        return []

    import torch_npu
    if not hasattr(torch_npu, 'optim'):
        return []

    torch_npu_optimizers = []
    for module_name in dir(torch_npu.optim):
        if module_name.startswith('__') or module_name in OPTIMIZERS:
            continue
        _optim = getattr(torch_npu.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module(module=_optim)
            torch_npu_optimizers.append(module_name)
    return torch_npu_optimizers


NPU_OPTIMIZERS = register_torch_npu_optimizers()
def register_dadaptation_optimizers() -> List[str]:
    """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    dadaptation_optimizers = []
    try:
        import dadaptation
    except ImportError:
        pass
    else:
        for module_name in ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD']:
            _optim = getattr(dadaptation, module_name)
            if inspect.isclass(_optim) and issubclass(_optim,
                                                      torch.optim.Optimizer):
                OPTIMIZERS.register_module(module=_optim)
                dadaptation_optimizers.append(module_name)
    return dadaptation_optimizers


DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()
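
# When ``dadaptation`` is installed, its optimizers are addressable by class
# name like any other registry entry. A hedged sketch (``lr=1.0`` is the
# package's documented default for D-Adaptation, not a tuned value):
#
#     optimizer = OPTIMIZERS.build(
#         dict(type='DAdaptAdam', lr=1.0, params=model.parameters()))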
def register_lion_optimizers() -> List[str]:
    """Register the Lion optimizer to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    optimizers = []
    try:
        from lion_pytorch import Lion
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(module=Lion)
        optimizers.append('Lion')
    return optimizers


LION_OPTIMIZERS = register_lion_optimizers()
def register_sophia_optimizers() -> List[str]:
    """Register Sophia optimizers to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    optimizers = []
    try:
        import Sophia
    except ImportError:
        pass
    else:
        for module_name in dir(Sophia):
            _optim = getattr(Sophia, module_name)
            if inspect.isclass(_optim) and issubclass(_optim,
                                                      torch.optim.Optimizer):
                OPTIMIZERS.register_module(module=_optim)
                optimizers.append(module_name)
    return optimizers


SOPHIA_OPTIMIZERS = register_sophia_optimizers()
def register_bitsandbytes_optimizers() -> List[str]:
    """Register optimizers in ``bitsandbytes`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    bitsandbytes_optimizers = []
    try:
        import bitsandbytes as bnb
    except ImportError:
        pass
    else:
        for module_name in [
                'AdamW8bit', 'Adam8bit', 'Adagrad8bit', 'PagedAdam8bit',
                'PagedAdamW8bit', 'LAMB8bit', 'LARS8bit', 'RMSprop8bit',
                'Lion8bit', 'PagedLion8bit', 'SGD8bit'
        ]:
            _optim = getattr(bnb.optim, module_name)
            if inspect.isclass(_optim) and issubclass(_optim,
                                                      torch.optim.Optimizer):
                OPTIMIZERS.register_module(module=_optim)
                bitsandbytes_optimizers.append(module_name)
    return bitsandbytes_optimizers


BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers()
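
# With ``bitsandbytes`` installed, the 8-bit variants drop into an optimizer
# wrapper config like any other registered optimizer. A hedged sketch (the
# hyperparameters are illustrative only):
#
#     optim_wrapper = dict(
#         type='OptimWrapper',
#         optimizer=dict(type='AdamW8bit', lr=1e-4, weight_decay=0.01))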
def register_transformers_optimizers() -> List[str]:
    """Register Adafactor in ``transformers`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    transformer_optimizers = []
    try:
        from transformers import Adafactor
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
        transformer_optimizers.append('Adafactor')
    return transformer_optimizers


TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()
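
# Adafactor is registered under an explicit name because ``transformers``
# exposes the class directly. A hedged sketch of its learning-rate-free
# configuration (see the transformers docs for the exact semantics):
#
#     optimizer = OPTIMIZERS.build(
#         dict(type='Adafactor', lr=None, scale_parameter=True,
#              relative_step=True, params=model.parameters()))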
def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
    """Build an ``OptimWrapper`` from config.

    If ``constructor`` is set in ``cfg``, this method will build an optimizer
    wrapper constructor and use it to build the optimizer wrapper. If
    ``constructor`` is not set, ``DefaultOptimWrapperConstructor`` will be
    used by default.

    Args:
        model (nn.Module): Model to be optimized.
        cfg (dict): Config of the optimizer wrapper, optimizer constructor
            and optimizer.

    Returns:
        OptimWrapper: The built optimizer wrapper.
    """
    optim_wrapper_cfg = copy.deepcopy(cfg)
    constructor_type = optim_wrapper_cfg.pop('constructor',
                                             'DefaultOptimWrapperConstructor')
    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
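
    # NPU devices such as the Ascend 910 may support only mixed precision
    # training, so fall back to ``AmpOptimWrapper`` to keep training usable.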
    if is_npu_available() and not is_npu_support_full_precision():
        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
        dict(
            type=constructor_type,
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg))
    optim_wrapper = optim_wrapper_constructor(model)
    return optim_wrapper
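
# A hedged end-to-end sketch of ``build_optim_wrapper``; the learning rates
# and the ``backbone`` custom key are illustrative, not part of this module:
#
#     model = nn.Linear(2, 2)
#     optim_wrapper = build_optim_wrapper(
#         model,
#         dict(
#             type='OptimWrapper',
#             optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05),
#             paramwise_cfg=dict(
#                 custom_keys=dict(backbone=dict(lr_mult=0.1)))))
#     loss = model(torch.ones(1, 2)).sum()
#     optim_wrapper.update_params(loss)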