import copy
import logging
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
from typing import Iterable, List, Optional, Union

import torch.nn as nn

from mmengine.dist import master_only
from mmengine.logging import MMLogger, print_log

from .weight_init import PretrainedInit, initialize, update_init_info
from .wrappers.utils import is_model_wrapper


class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in OpenMMLab. ``BaseModule`` is a wrapper
    of ``torch.nn.Module`` with the additional functionality of parameter
    initialization. Compared with ``torch.nn.Module``, ``BaseModule`` mainly
    adds three attributes.

    - ``init_cfg``: the config that controls the initialization.
    - ``init_weights``: the function that initializes parameters and records
      initialization information.
    - ``_params_init_info``: used to track the parameter initialization
      information. This attribute exists only while ``init_weights`` is
      being executed.

    Note:
        :obj:`PretrainedInit` has a higher priority than any other
        initializer. The loaded pretrained weights will overwrite the
        previously initialized weights.

    Args:
        init_cfg (dict or List[dict], optional): Initialization config dict.
    """

    def __init__(self, init_cfg: Union[dict, List[dict], None] = None):
        """Initialize BaseModule, inherited from ``torch.nn.Module``."""
        super().__init__()
        self._is_init = False
        self.init_cfg = copy.deepcopy(init_cfg)

    @property
    def is_init(self):
        return self._is_init

    @is_init.setter
    def is_init(self, value):
        self._is_init = value

    def init_weights(self):
        """Initialize the weights."""

        is_top_level_module = False
        # Check whether this is the top-level module of the call.
        if not hasattr(self, '_params_init_info'):
            # `_params_init_info` records the initialization information
            # of every parameter in the model.
            self._params_init_info = defaultdict(dict)
            is_top_level_module = True

            # Initialize `_params_init_info` with a default message; when
            # the `tmp_mean_value` of a parameter is detected to change,
            # the related initialization information is updated.
            for name, param in self.named_parameters():
                self._params_init_info[param][
                    'init_info'] = f'The value is the same before and ' \
                                   f'after calling `init_weights` ' \
                                   f'of {self.__class__.__name__} '
                self._params_init_info[param][
                    'tmp_mean_value'] = param.data.mean().cpu()

            # Pass `_params_init_info` to all submodules. All submodules
            # share the same dict, so it is updated when parameters are
            # modified at any level of the model.
            for sub_module in self.modules():
                sub_module._params_init_info = self._params_init_info

        module_name = self.__class__.__name__
        if not self._is_init:
            if self.init_cfg:
                print_log(
                    f'initialize {module_name} with init_cfg {self.init_cfg}',
                    logger='current',
                    level=logging.DEBUG)

                init_cfgs = self.init_cfg
                if isinstance(self.init_cfg, dict):
                    init_cfgs = [self.init_cfg]

                # `PretrainedInit` has a higher priority than any other
                # initializer, so pretrained configs are applied last to
                # overwrite the previously initialized weights.
                other_cfgs = []
                pretrained_cfg = []
                for init_cfg in init_cfgs:
                    assert isinstance(init_cfg, dict)
                    if (init_cfg['type'] == 'Pretrained'
                            or init_cfg['type'] is PretrainedInit):
                        pretrained_cfg.append(init_cfg)
                    else:
                        other_cfgs.append(init_cfg)

                initialize(self, other_cfgs)

            for m in self.children():
                if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
                    m = m.module
                if hasattr(m, 'init_weights') and not getattr(
                        m, 'is_init', False):
                    m.init_weights()
                    # users may overload `init_weights` in a subclass
                    update_init_info(
                        m,
                        init_info=f'Initialized by '
                        f'user-defined `init_weights`'
                        f' in {m.__class__.__name__} ')
            if self.init_cfg and pretrained_cfg:
                initialize(self, pretrained_cfg)
            self._is_init = True
        else:
            print_log(
                f'init_weights of {self.__class__.__name__} has '
                f'been called more than once.',
                logger='current',
                level=logging.WARNING)

        if is_top_level_module:
            self._dump_init_info()

            for sub_module in self.modules():
                del sub_module._params_init_info

    @master_only
    def _dump_init_info(self):
        """Dump the initialization information to a file named
        ``initialization.log.json`` in workdir."""

        logger = MMLogger.get_current_instance()
        with_file_handler = False
        # dump the information to the logger file if there is a `FileHandler`
        for handler in logger.handlers:
            if isinstance(handler, FileHandler):
                handler.stream.write(
                    'Name of parameter - Initialization information\n')
                for name, param in self.named_parameters():
                    handler.stream.write(
                        f'\n{name} - {param.shape}: '
                        f"\n{self._params_init_info[param]['init_info']} \n")
                handler.stream.flush()
                with_file_handler = True
        if not with_file_handler:
            for name, param in self.named_parameters():
                logger.info(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n ")

    def __repr__(self):
        s = super().__repr__()
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s
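

# A minimal usage sketch of ``BaseModule`` (illustrative only; ``ToyModel``
# and its layers are hypothetical, not part of this module). ``init_cfg``
# is stored at construction time and applied only when ``init_weights`` is
# called, typically once, after the whole model has been assembled:
#
#     class ToyModel(BaseModule):
#
#         def __init__(self, init_cfg=None):
#             super().__init__(init_cfg=init_cfg)
#             self.linear = nn.Linear(2, 2)
#
#     toy = ToyModel(init_cfg=dict(type='Constant', val=1, bias=0))
#     toy.init_weights()  # fills weights with 1 and biases with 0
#
# When a ``Pretrained`` config is mixed with other initializers, it is
# applied last, so the loaded checkpoint overwrites the other results:
#
#     toy = ToyModel(init_cfg=[
#         dict(type='Constant', val=1, bias=0),
#         dict(type='Pretrained', checkpoint='...'),  # applied last
#     ])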


class Sequential(BaseModule, nn.Sequential):
    """Sequential module in OpenMMLab.

    Ensures that all modules in ``Sequential`` can have an initialization
    strategy different from that of the outer model.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, *args, init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.Sequential.__init__(self, *args)
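

# Illustrative sketch (hypothetical layers, not part of this module): the
# inner ``init_cfg`` overrides whatever strategy the outer model would
# otherwise apply to these layers.
#
#     block = Sequential(
#         nn.Conv2d(3, 8, 3),
#         nn.ReLU(),
#         init_cfg=dict(type='Kaiming', layer='Conv2d'))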


class ModuleList(BaseModule, nn.ModuleList):
    """ModuleList in OpenMMLab.

    Ensures that all modules in ``ModuleList`` can have an initialization
    strategy different from that of the outer model.

    Args:
        modules (iterable, optional): An iterable of modules to add.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 modules: Optional[Iterable] = None,
                 init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.ModuleList.__init__(self, modules)
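

# Illustrative sketch (hypothetical contents): each entry is registered as a
# submodule and shares the ``init_cfg`` given to the list.
#
#     layers = ModuleList(
#         [nn.Linear(4, 4) for _ in range(3)],
#         init_cfg=dict(type='Xavier', layer='Linear'))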


class ModuleDict(BaseModule, nn.ModuleDict):
    """ModuleDict in OpenMMLab.

    Ensures that all modules in ``ModuleDict`` can have an initialization
    strategy different from that of the outer model.

    Args:
        modules (dict, optional): A mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module).
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 modules: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.ModuleDict.__init__(self, modules)
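

# Illustrative sketch (hypothetical contents): values are registered as
# submodules keyed by name, with the shared ``init_cfg`` applied on
# ``init_weights``.
#
#     heads = ModuleDict(
#         dict(cls=nn.Linear(16, 10), reg=nn.Linear(16, 4)),
#         init_cfg=dict(type='Normal', std=0.01, layer='Linear'))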