| | |
| | import copy |
| | import warnings |
| | from abc import ABCMeta |
| | from collections import defaultdict |
| | from logging import FileHandler |
| |
|
| | import torch.nn as nn |
| |
|
| | from annotator.uniformer.mmcv.runner.dist_utils import master_only |
| | from annotator.uniformer.mmcv.utils.logging import get_logger, logger_initialized, print_log |
| |
|
| |
|
class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab.

    ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
    functionality of parameter initialization. Compared with
    ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

    - ``init_cfg``: the config to control the initialization.
    - ``init_weights``: The function of parameter
      initialization and recording initialization
      information.
    - ``_params_init_info``: Used to track the parameter
      initialization information. This attribute only
      exists during executing the ``init_weights``.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, init_cfg=None):
        """Initialize BaseModule, inherited from `torch.nn.Module`"""

        super(BaseModule, self).__init__()

        # Flag guarding against running ``init_weights`` twice on the
        # same module; flipped to True at the end of ``init_weights``.
        self._is_init = False

        # Deep-copy so later mutation of the caller's config dict cannot
        # silently change this module's initialization behavior.
        self.init_cfg = copy.deepcopy(init_cfg)

    @property
    def is_init(self):
        # Read-only view of the initialization flag.
        return self._is_init

    def init_weights(self):
        """Initialize the weights."""

        # Only the outermost module that starts the recursion creates the
        # bookkeeping dict below and is responsible for dumping it and
        # cleaning it up afterwards.
        is_top_level_module = False
        if not hasattr(self, '_params_init_info'):
            # `_params_init_info` records initialization information of the
            # parameters. Keyed by the :obj:`nn.Parameter` objects of the
            # model; each value is a dict containing
            # - init_info (str): describes the initialization.
            # - tmp_mean_value: mean of the parameter at registration time,
            #   used to detect whether the parameter was later modified.
            # This attribute is deleted again at the end of the top-level
            # ``init_weights`` call.
            self._params_init_info = defaultdict(dict)
            is_top_level_module = True

            # Seed an entry for every parameter; the message is replaced by
            # `update_init_info` when a change in `tmp_mean_value` shows the
            # parameter was actually (re-)initialized.
            for name, param in self.named_parameters():
                self._params_init_info[param][
                    'init_info'] = f'The value is the same before and ' \
                                   f'after calling `init_weights` ' \
                                   f'of {self.__class__.__name__} '
                self._params_init_info[param][
                    'tmp_mean_value'] = param.data.mean()

            # Share the same `_params_init_info` dict with every submodule
            # so updates made at any level of the model are visible here.
            for sub_module in self.modules():
                sub_module._params_init_info = self._params_init_info

        # Reuse the first already-initialized logger if any exists;
        # otherwise fall back to a logger named 'mmcv'.
        logger_names = list(logger_initialized.keys())
        logger_name = logger_names[0] if logger_names else 'mmcv'

        # Imported here (not at module top) — presumably to avoid a
        # circular import with the ..cnn package; TODO confirm.
        from ..cnn import initialize
        from ..cnn.utils.weight_init import update_init_info
        module_name = self.__class__.__name__
        if not self._is_init:
            if self.init_cfg:
                print_log(
                    f'initialize {module_name} with init_cfg {self.init_cfg}',
                    logger=logger_name)
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, dict):
                    # Stop here for 'Pretrained' so the loaded pre-trained
                    # weights are not overwritten by the children's
                    # `init_weights` below.
                    if self.init_cfg['type'] == 'Pretrained':
                        return

            # Recurse into direct children; users may have overloaded
            # `init_weights`, so record that in the init info.
            for m in self.children():
                if hasattr(m, 'init_weights'):
                    m.init_weights()
                    update_init_info(
                        m,
                        init_info=f'Initialized by '
                        f'user-defined `init_weights`'
                        f' in {m.__class__.__name__} ')

            self._is_init = True
        else:
            warnings.warn(f'init_weights of {self.__class__.__name__} has '
                          f'been called more than once.')

        if is_top_level_module:
            # The root call dumps the collected information and removes the
            # temporary bookkeeping attribute from every submodule.
            self._dump_init_info(logger_name)

            for sub_module in self.modules():
                del sub_module._params_init_info

    @master_only
    def _dump_init_info(self, logger_name):
        """Dump the initialization information to a file named
        `initialization.log.json` in workdir.

        Args:
            logger_name (str): The name of logger.
        """

        logger = get_logger(logger_name)

        with_file_handler = False
        # Prefer writing directly to the logger's file handler(s), bypassing
        # log formatting so the parameter report stays readable.
        for handler in logger.handlers:
            if isinstance(handler, FileHandler):
                handler.stream.write(
                    'Name of parameter - Initialization information\n')
                for name, param in self.named_parameters():
                    handler.stream.write(
                        f'\n{name} - {param.shape}: '
                        f"\n{self._params_init_info[param]['init_info']} \n")
                handler.stream.flush()
                with_file_handler = True
        if not with_file_handler:
            # No file handler attached: fall back to normal logging output.
            for name, param in self.named_parameters():
                print_log(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n ",
                    logger=logger_name)

    def __repr__(self):
        s = super().__repr__()
        # Append init_cfg so it is visible when the model is printed.
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s
| |
|
| |
|
class Sequential(BaseModule, nn.Sequential):
    """Sequential module in openmmlab.

    A drop-in ``torch.nn.Sequential`` that additionally carries the
    ``init_cfg``-driven initialization machinery of ``BaseModule``.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, *args, init_cfg=None):
        # Invoke each base initializer explicitly (not via a cooperative
        # super() chain): BaseModule first records ``init_cfg``, then
        # nn.Sequential registers the given child modules.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.Sequential.__init__(self, *args)
| |
|
| |
|
class ModuleList(BaseModule, nn.ModuleList):
    """ModuleList in openmmlab.

    A drop-in ``torch.nn.ModuleList`` that additionally carries the
    ``init_cfg``-driven initialization machinery of ``BaseModule``.

    Args:
        modules (iterable, optional): an iterable of modules to add.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, modules=None, init_cfg=None):
        # Invoke each base initializer explicitly (not via a cooperative
        # super() chain): BaseModule first records ``init_cfg``, then
        # nn.ModuleList registers the given child modules.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.ModuleList.__init__(self, modules)
| |
|