| |
| import copy |
| import warnings |
| from abc import ABCMeta |
| from collections import defaultdict |
| from logging import FileHandler |
|
|
| import torch.nn as nn |
|
|
| from annotator.uniformer.mmcv.runner.dist_utils import master_only |
| from annotator.uniformer.mmcv.utils.logging import get_logger, logger_initialized, print_log |
|
|
|
|
class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab.

    ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
    functionality of parameter initialization. Compared with
    ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.

    - ``init_cfg``: the config to control the initialization.
    - ``init_weights``: The function of parameter
      initialization and recording initialization
      information.
    - ``_params_init_info``: Used to track the parameter
      initialization information. This attribute only
      exists during executing the ``init_weights``.

    Args:
        init_cfg (dict, optional): Initialization config dict. Non-dict
            values are passed to ``initialize`` unchanged as well
            (presumably a list of dicts -- confirm against ``initialize``).
    """

    def __init__(self, init_cfg=None):
        """Initialize BaseModule, inherited from `torch.nn.Module`."""
        super(BaseModule, self).__init__()
        # Flipped to True by ``init_weights`` once this module and its
        # children have been initialized; repeated calls then only warn.
        self._is_init = False
        # Deep copy so that later mutation of the caller's config object
        # cannot affect this module (and vice versa).
        self.init_cfg = copy.deepcopy(init_cfg)

    @property
    def is_init(self):
        """bool: Whether ``init_weights`` has completed for this module."""
        return self._is_init

    def init_weights(self):
        """Initialize the weights of this module and its children.

        The outermost call (the one that finds no ``_params_init_info``
        attribute on ``self``) acts as the top-level call. It creates a
        ``defaultdict(dict)`` keyed by parameter, pre-fills every entry
        with a default ``init_info`` message and the parameter's current
        mean (``tmp_mean_value``), and attaches that same dict to every
        submodule so all levels record into one shared table.

        If this module is not initialized yet, ``init_cfg`` (when set) is
        applied via ``initialize``. Unless ``init_cfg`` is a dict whose
        ``type`` is ``'Pretrained'``, each child defining ``init_weights``
        is initialized recursively, its info is recorded with
        ``update_init_info`` and ``_is_init`` is set to True. (In the
        ``'Pretrained'`` case the children are skipped so they cannot
        overwrite the loaded weights, and ``_is_init`` stays False, as
        before.) If the module is already initialized, only a warning is
        emitted.

        The top-level call then dumps the collected information and removes
        ``_params_init_info`` from every submodule.
        """
        is_top_level_module = False
        if not hasattr(self, '_params_init_info'):
            # Shared table: key is an ``nn.Parameter``; value is a dict with
            # ``init_info`` (description string) and ``tmp_mean_value``
            # (mean of the parameter, used to detect later modification).
            self._params_init_info = defaultdict(dict)
            is_top_level_module = True

            for param in self.parameters():
                self._params_init_info[param][
                    'init_info'] = f'The value is the same before and ' \
                                   f'after calling `init_weights` ' \
                                   f'of {self.__class__.__name__} '
                self._params_init_info[param][
                    'tmp_mean_value'] = param.data.mean()

            # Every submodule records into the same table, so updates made
            # at any depth of the model are visible to the top level.
            for sub_module in self.modules():
                sub_module._params_init_info = self._params_init_info

        # Use the first initialized logger if there is one, else 'mmcv'.
        logger_names = list(logger_initialized.keys())
        logger_name = logger_names[0] if logger_names else 'mmcv'

        # Function-scope imports, presumably to avoid an import cycle with
        # ``..cnn`` -- NOTE(review): confirm.
        from ..cnn import initialize
        from ..cnn.utils.weight_init import update_init_info
        module_name = self.__class__.__name__

        # The top-level cleanup lives in ``finally``: an early ``return``
        # (pretrained branch) or an exception used to leave the shared
        # table attached to the whole model, after which no later call was
        # ever treated as top-level, so the table was never dumped or freed.
        try:
            if not self._is_init:
                is_pretrained = False
                if self.init_cfg:
                    print_log(
                        f'initialize {module_name} with init_cfg {self.init_cfg}',
                        logger=logger_name)
                    initialize(self, self.init_cfg)
                    # Prevent the weights just loaded from a pre-trained
                    # checkpoint from being overwritten by the children's
                    # ``init_weights``.
                    is_pretrained = (
                        isinstance(self.init_cfg, dict)
                        and self.init_cfg.get('type') == 'Pretrained')

                if not is_pretrained:
                    for m in self.children():
                        if hasattr(m, 'init_weights'):
                            m.init_weights()
                            # Users may override ``init_weights``, so record
                            # that this child was initialized by it.
                            update_init_info(
                                m,
                                init_info=f'Initialized by '
                                f'user-defined `init_weights`'
                                f' in {m.__class__.__name__} ')
                    self._is_init = True
            else:
                warnings.warn(f'init_weights of {self.__class__.__name__} has '
                              f'been called more than once.')

            if is_top_level_module:
                self._dump_init_info(logger_name)
        finally:
            if is_top_level_module:
                for sub_module in self.modules():
                    # Guard against submodules that never received the
                    # shared table (e.g. ones attached during init).
                    if hasattr(sub_module, '_params_init_info'):
                        del sub_module._params_init_info

    @master_only
    def _dump_init_info(self, logger_name):
        """Write the collected per-parameter initialization information.

        The information is written directly to the open stream of every
        ``logging.FileHandler`` attached to the logger returned by
        ``get_logger(logger_name)``. If no such handler could be written to,
        it is emitted through ``print_log`` instead. Decorated with
        ``master_only`` (presumably a no-op on non-master ranks in
        distributed runs -- confirm in ``dist_utils``).

        Args:
            logger_name (str): The name of logger.
        """
        logger = get_logger(logger_name)

        with_file_handler = False
        for handler in logger.handlers:
            if not isinstance(handler, FileHandler):
                continue
            stream = handler.stream
            if stream is None:
                # A FileHandler created with ``delay=True`` (before its first
                # emit) or already closed has no stream; writing to it would
                # raise, so fall back to ``print_log`` for it.
                continue
            stream.write('Name of parameter - Initialization information\n')
            for name, param in self.named_parameters():
                stream.write(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n")
            stream.flush()
            with_file_handler = True

        if not with_file_handler:
            for name, param in self.named_parameters():
                print_log(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n ",
                    logger=logger_name)

    def __repr__(self):
        """Return the ``nn.Module`` repr, plus ``init_cfg`` when it is set."""
        s = super().__repr__()
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s
|
|
|
|
class Sequential(BaseModule, nn.Sequential):
    """``torch.nn.Sequential`` container with ``BaseModule`` init support.

    Args:
        *args: Positional arguments forwarded unchanged to
            ``torch.nn.Sequential.__init__``.
        init_cfg (dict, optional): Initialization config dict, stored by
            ``BaseModule``.
    """

    def __init__(self, *args, init_cfg=None):
        # Set up the BaseModule state (``_is_init`` / ``init_cfg``) first,
        # then let nn.Sequential register the given modules.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.Sequential.__init__(self, *args)
|
|
|
|
class ModuleList(BaseModule, nn.ModuleList):
    """``torch.nn.ModuleList`` container with ``BaseModule`` init support.

    Args:
        modules (iterable, optional): Modules to add; forwarded unchanged to
            ``torch.nn.ModuleList.__init__``.
        init_cfg (dict, optional): Initialization config dict, stored by
            ``BaseModule``.
    """

    def __init__(self, modules=None, init_cfg=None):
        # Set up the BaseModule state (``_is_init`` / ``init_cfg``) first,
        # then let nn.ModuleList register the given modules.
        BaseModule.__init__(self, init_cfg=init_cfg)
        nn.ModuleList.__init__(self, modules=modules)
|
|